@@ -5,12 +5,14 @@ import (
55 "errors"
66 "fmt"
77 "net/http"
8+ "sort"
89 "strings"
910 "testing"
1011 "time"
1112
1213 "github.com/databricks/databricks-sdk-go/config/credentials"
1314 "github.com/databricks/databricks-sdk-go/credentials/u2m"
15+ "github.com/google/go-cmp/cmp"
1416 "golang.org/x/oauth2"
1517)
1618
@@ -53,14 +55,37 @@ var (
5355 errInvalidRefreshToken = & u2m.InvalidRefreshTokenError {}
5456)
5557
58+ // mockPersistentAuthFactory returns a persistentAuthFactory that returns ts.
59+ func mockPersistentAuthFactory (ts oauth2.TokenSource ) persistentAuthFactory {
60+ return func (ctx context.Context , opts ... u2m.PersistentAuthOption ) (oauth2.TokenSource , error ) {
61+ return ts , nil
62+ }
63+ }
64+
65+ // capturingPersistentAuthFactory returns a persistentAuthFactory that applies
66+ // options to a real PersistentAuth and calls onCapture, allowing tests to spy
67+ // on the options passed. It returns ts for token operations.
68+ func capturingPersistentAuthFactory (ts oauth2.TokenSource , onCapture func (* u2m.PersistentAuth )) persistentAuthFactory {
69+ return func (ctx context.Context , opts ... u2m.PersistentAuthOption ) (oauth2.TokenSource , error ) {
70+ pa , err := u2m .NewPersistentAuth (ctx , opts ... )
71+ if err != nil {
72+ return nil , err
73+ }
74+ if onCapture != nil {
75+ onCapture (pa )
76+ }
77+ return ts , nil
78+ }
79+ }
80+
5681func TestU2MCredentials_Configure (t * testing.T ) {
5782 testCases := []struct {
58- desc string
59- cfg * Config
60- testTokenSource * testTokenSource
61- wantConfigErr string // error message from Configure()
62- wantHeaderErr string // error message from SetHeaders()
63- wantAuthHeader string // expected Authorization header
83+ desc string
84+ cfg * Config
85+ tokenSource * testTokenSource
86+ wantConfigErr string // error message from Configure()
87+ wantHeaderErr string // error message from SetHeaders()
88+ wantAuthHeader string // expected Authorization header
6489 }{
6590 {
6691 desc : "missing host returns error" ,
@@ -74,7 +99,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
7499 cfg : & Config {
75100 Host : "https://workspace.cloud.databricks.com" ,
76101 },
77- testTokenSource : & testTokenSource {
102+ tokenSource : & testTokenSource {
78103 token : testValidToken ,
79104 },
80105 wantAuthHeader : "Bearer valid-access-token" ,
@@ -85,7 +110,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
85110 Host : "https://accounts.cloud.databricks.com" ,
86111 AccountID : "abc-123" ,
87112 },
88- testTokenSource : & testTokenSource {
113+ tokenSource : & testTokenSource {
89114 token : testValidToken ,
90115 },
91116 wantAuthHeader : "Bearer valid-access-token" ,
@@ -95,7 +120,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
95120 cfg : & Config {
96121 Host : "https://workspace.cloud.databricks.com" ,
97122 },
98- testTokenSource : & testTokenSource {
123+ tokenSource : & testTokenSource {
99124 token : testExpiredToken ,
100125 },
101126 wantAuthHeader : "Bearer expired-access-token" ,
@@ -105,7 +130,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
105130 cfg : & Config {
106131 Host : "https://workspace.cloud.databricks.com" ,
107132 },
108- testTokenSource : & testTokenSource {
133+ tokenSource : & testTokenSource {
109134 err : errNetwork ,
110135 },
111136 wantHeaderErr : "network timeout" ,
@@ -115,7 +140,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
115140 cfg : & Config {
116141 Host : "https://workspace.cloud.databricks.com" ,
117142 },
118- testTokenSource : & testTokenSource {
143+ tokenSource : & testTokenSource {
119144 err : errAuthentication ,
120145 },
121146 wantHeaderErr : "authentication failed" ,
@@ -127,7 +152,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
127152 Profile : "my-workspace" ,
128153 resolved : true ,
129154 },
130- testTokenSource : & testTokenSource {
155+ tokenSource : & testTokenSource {
131156 err : errInvalidRefreshToken ,
132157 },
133158 wantHeaderErr : "databricks auth login --profile my-workspace" ,
@@ -138,7 +163,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
138163 Host : "https://workspace.cloud.databricks.com" ,
139164 resolved : true ,
140165 },
141- testTokenSource : & testTokenSource {
166+ tokenSource : & testTokenSource {
142167 err : errInvalidRefreshToken ,
143168 },
144169 wantHeaderErr : "databricks auth login --host https://workspace.cloud.databricks.com" ,
@@ -151,7 +176,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
151176 Profile : "prod-account" ,
152177 resolved : true ,
153178 },
154- testTokenSource : & testTokenSource {
179+ tokenSource : & testTokenSource {
155180 err : errInvalidRefreshToken ,
156181 },
157182 wantHeaderErr : "databricks auth login --profile prod-account" ,
@@ -163,7 +188,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
163188 AccountID : "abc-123" ,
164189 resolved : true ,
165190 },
166- testTokenSource : & testTokenSource {
191+ tokenSource : & testTokenSource {
167192 err : errInvalidRefreshToken ,
168193 },
169194 wantHeaderErr : "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc-123" ,
@@ -175,7 +200,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
175200 Profile : "test" ,
176201 resolved : true ,
177202 },
178- testTokenSource : & testTokenSource {
203+ tokenSource : & testTokenSource {
179204 err : fmt .Errorf ("oauth2: %w" , errInvalidRefreshToken ),
180205 },
181206 wantHeaderErr : "databricks auth login --profile test" ,
@@ -187,7 +212,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
187212 AccountID : "abc-456" ,
188213 resolved : true ,
189214 },
190- testTokenSource : & testTokenSource {
215+ tokenSource : & testTokenSource {
191216 err : errInvalidRefreshToken ,
192217 },
193218 wantHeaderErr : "databricks auth login --host https://accounts.azure.databricks.net --account-id abc-456" ,
@@ -197,7 +222,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
197222 for _ , tc := range testCases {
198223 t .Run (tc .desc , func (t * testing.T ) {
199224 ctx := context .Background ()
200- u := u2mCredentials {testTokenSource : tc .testTokenSource }
225+ u := u2mCredentials {
226+ newPersistentAuth : mockPersistentAuthFactory (tc .tokenSource ),
227+ }
201228
202229 cp , gotConfigErr := u .Configure (ctx , tc .cfg )
203230
@@ -238,7 +265,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
238265func TestU2MCredentials_Configure_TokenCaching (t * testing.T ) {
239266 ts := & testTokenSource {token : testValidToken }
240267
241- u := u2mCredentials {testTokenSource : ts }
268+ u := u2mCredentials {
269+ newPersistentAuth : mockPersistentAuthFactory (ts ),
270+ }
242271 cfg := & Config {
243272 Host : "https://workspace.cloud.databricks.com" ,
244273 }
@@ -261,3 +290,54 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
261290 t .Errorf ("token source call count = %d, want 1 (should use cache)" , ts .counts )
262291 }
263292}
293+
294+ func TestU2MCredentials_Configure_Scopes (t * testing.T ) {
295+ testCases := []struct {
296+ desc string
297+ configScopes []string
298+ expectedScopes []string
299+ sortScopes bool // whether to sort captured scopes before comparison
300+ }{
301+ {
302+ desc : "default scopes when not specified" ,
303+ configScopes : nil ,
304+ expectedScopes : []string {"all-apis" },
305+ sortScopes : false ,
306+ },
307+ {
308+ desc : "custom scopes are passed through" ,
309+ configScopes : []string {"sql" , "clusters" },
310+ expectedScopes : []string {"clusters" , "sql" }, // sorted during config resolution
311+ sortScopes : true ,
312+ },
313+ }
314+
315+ for _ , tc := range testCases {
316+ t .Run (tc .desc , func (t * testing.T ) {
317+ ts := & testTokenSource {token : testValidToken }
318+ var capturedScopes []string
319+
320+ u := u2mCredentials {
321+ newPersistentAuth : capturingPersistentAuthFactory (ts , func (pa * u2m.PersistentAuth ) {
322+ capturedScopes = pa .GetScopes ()
323+ }),
324+ }
325+ cfg := & Config {
326+ Host : "https://workspace.cloud.databricks.com" ,
327+ Scopes : tc .configScopes ,
328+ }
329+
330+ _ , err := u .Configure (context .Background (), cfg )
331+ if err != nil {
332+ t .Fatalf ("Configure() error = %v" , err )
333+ }
334+
335+ if tc .sortScopes {
336+ sort .Strings (capturedScopes )
337+ }
338+ if diff := cmp .Diff (tc .expectedScopes , capturedScopes ); diff != "" {
339+ t .Errorf ("scopes mismatch (-want +got):\n %s" , diff )
340+ }
341+ })
342+ }
343+ }
0 commit comments