@@ -14,12 +14,12 @@ import (
1414 "time"
1515 "unicode/utf8"
1616
17- sdkconfig "github.com/databricks/databricks-sdk-go/config"
18- "github.com/databricks/databricks-sdk-go/service/iam"
19-
2017 "github.com/databricks/cli/libs/env"
18+ "github.com/databricks/cli/libs/testproxy"
2119 "github.com/databricks/cli/libs/testserver"
2220 "github.com/databricks/databricks-sdk-go"
21+ sdkconfig "github.com/databricks/databricks-sdk-go/config"
22+ "github.com/databricks/databricks-sdk-go/service/iam"
2323 "github.com/google/uuid"
2424 "github.com/stretchr/testify/assert"
2525 "github.com/stretchr/testify/require"
@@ -40,81 +40,115 @@ func isTruePtr(value *bool) bool {
4040 return value != nil && * value
4141}
4242
43- func PrepareServerAndClient (t * testing.T , config TestConfig , logRequests bool , outputDir string , mu * sync. Mutex ) (* sdkconfig.Config , iam.User ) {
43+ func PrepareServerAndClient (t * testing.T , config TestConfig , logRequests bool , outputDir string ) (* sdkconfig.Config , iam.User ) {
4444 cloudEnv := os .Getenv ("CLOUD_ENV" )
45+ recordRequests := isTruePtr (config .RecordRequests )
4546
46- // If we are running on a cloud environment, use the host configured in the
47- // environment.
4847 if cloudEnv != "" {
4948 w , err := databricks .NewWorkspaceClient ()
5049 require .NoError (t , err )
5150
5251 user , err := w .CurrentUser .Me (context .Background ())
5352 require .NoError (t , err , "Failed to get current user" )
5453
55- return w .Config , * user
56- }
54+ cfg := w .Config
5755
58- recordRequests := isTruePtr (config .RecordRequests )
56+ // If we are running in a cloud environment AND we are recording requests,
57+ // start a dedicated server to act as a reverse proxy to a real Databricks workspace.
58+ if recordRequests {
59+ host , token := startProxyServer (t , logRequests , config .IncludeRequestHeaders , outputDir )
60+ cfg = & sdkconfig.Config {
61+ Host : host ,
62+ Token : token ,
63+ }
64+ }
5965
60- tokenSuffix := strings . ReplaceAll ( uuid . NewString (), "-" , "" )
61- token := "dbapi" + tokenSuffix
66+ return cfg , * user
67+ }
6268
63- // If we are not recording requests, and no custom server server stubs are configured,
69+ // If we are not recording requests, and no custom server stubs are configured,
6470 // use the default shared server.
6571 if len (config .Server ) == 0 && ! recordRequests {
66- return & sdkconfig.Config {
72+ // Use a unique token for each test. This allows us to maintain
73+ // separate state for each test in fake workspaces.
74+ tokenSuffix := strings .ReplaceAll (uuid .NewString (), "-" , "" )
75+ token := "dbapi" + tokenSuffix
76+
77+ cfg := & sdkconfig.Config {
6778 Host : os .Getenv ("DATABRICKS_DEFAULT_HOST" ),
6879 Token : token ,
69- }, TestUser
70- }
80+ }
7181
72- host := startDedicatedServer (t , config .Server , recordRequests , logRequests , config .IncludeRequestHeaders , outputDir , mu )
82+ return cfg , TestUser
83+ }
7384
74- return & sdkconfig.Config {
85+ // Default case. Start a dedicated local server for the test with the server stubs configured
86+ // as overrides.
87+ host , token := startLocalServer (t , config .Server , recordRequests , logRequests , config .IncludeRequestHeaders , outputDir )
88+ cfg := & sdkconfig.Config {
7589 Host : host ,
7690 Token : token ,
77- }, TestUser
91+ }
92+
93+ // For the purposes of replacements, use testUser for local runs.
94+ // Note, users might have overriden /api/2.0/preview/scim/v2/Me but that should not affect the replacement:
95+ return cfg , TestUser
7896}
7997
80- func startDedicatedServer (t * testing.T ,
98+ func recordRequestsCallback (t * testing.T , includeHeaders []string , outputDir string ) func (request * testserver.Request ) {
99+ mu := sync.Mutex {}
100+
101+ return func (request * testserver.Request ) {
102+ mu .Lock ()
103+ defer mu .Unlock ()
104+
105+ req := getLoggedRequest (request , includeHeaders )
106+ reqJson , err := json .MarshalIndent (req , "" , " " )
107+ assert .NoErrorf (t , err , "Failed to json-encode: %#v" , req )
108+
109+ requestsPath := filepath .Join (outputDir , "out.requests.txt" )
110+ f , err := os .OpenFile (requestsPath , os .O_CREATE | os .O_APPEND | os .O_WRONLY , 0o644 )
111+ assert .NoError (t , err )
112+ defer f .Close ()
113+
114+ _ , err = f .WriteString (string (reqJson ) + "\n " )
115+ assert .NoError (t , err )
116+ }
117+ }
118+
119+ func logResponseCallback (t * testing.T ) func (request * testserver.Request , response * testserver.EncodedResponse ) {
120+ mu := sync.Mutex {}
121+
122+ return func (request * testserver.Request , response * testserver.EncodedResponse ) {
123+ mu .Lock ()
124+ defer mu .Unlock ()
125+
126+ t .Logf ("%d %s %s\n %s\n %s" ,
127+ response .StatusCode , request .Method , request .URL ,
128+ formatHeadersAndBody ("> " , request .Headers , request .Body ),
129+ formatHeadersAndBody ("# " , response .Headers , response .Body ),
130+ )
131+ }
132+ }
133+
134+ func startLocalServer (t * testing.T ,
81135 stubs []ServerStub ,
82136 recordRequests bool ,
83137 logRequests bool ,
84138 includeHeaders []string ,
85139 outputDir string ,
86- mu * sync.Mutex ,
87- ) string {
140+ ) (string , string ) {
88141 s := testserver .New (t )
89142
143+ // Record API requests in out.requests.txt if RecordRequests is true
144+ // in test.toml
90145 if recordRequests {
91- requestsPath := filepath .Join (outputDir , "out.requests.txt" )
92- s .RequestCallback = func (request * testserver.Request ) {
93- req := getLoggedRequest (request , includeHeaders )
94- reqJson , err := json .MarshalIndent (req , "" , " " )
95-
96- mu .Lock ()
97- defer mu .Unlock ()
98-
99- assert .NoErrorf (t , err , "Failed to json-encode: %#v" , req )
100-
101- f , err := os .OpenFile (requestsPath , os .O_CREATE | os .O_APPEND | os .O_WRONLY , 0o644 )
102- assert .NoError (t , err )
103- defer f .Close ()
104-
105- _ , err = f .WriteString (string (reqJson ) + "\n " )
106- assert .NoError (t , err )
107- }
146+ s .RequestCallback = recordRequestsCallback (t , includeHeaders , outputDir )
108147 }
109148
149+ // Log API responses if the -logrequests flag is set.
110150 if logRequests {
111- s .ResponseCallback = func (request * testserver.Request , response * testserver.EncodedResponse ) {
112- t .Logf ("%d %s %s\n %s\n %s" ,
113- response .StatusCode , request .Method , request .URL ,
114- formatHeadersAndBody ("> " , request .Headers , request .Body ),
115- formatHeadersAndBody ("# " , response .Headers , response .Body ),
116- )
117- }
151+ s .ResponseCallback = logResponseCallback (t )
118152 }
119153
120154 for ind := range stubs {
@@ -132,8 +166,25 @@ func startDedicatedServer(t *testing.T,
132166
133167 // The earliest handlers take precedence, add default handlers last
134168 addDefaultHandlers (s )
169+ return s .URL , "dbapi123"
170+ }
171+
172+ func startProxyServer (t * testing.T ,
173+ logRequests bool ,
174+ includeHeaders []string ,
175+ outputDir string ,
176+ ) (string , string ) {
177+ s := testproxy .New (t )
178+
179+ // Always record requests for a proxy server.
180+ s .RequestCallback = recordRequestsCallback (t , includeHeaders , outputDir )
181+
182+ // Log API responses if the -logrequests flag is set.
183+ if logRequests {
184+ s .ResponseCallback = logResponseCallback (t )
185+ }
135186
136- return s .URL
187+ return s .URL , "dbapi1234"
137188}
138189
139190type LoggedRequest struct {
0 commit comments