55 "errors"
66 "fmt"
77 "net/http"
8+ "sort"
89 "strings"
910 "testing"
1011 "time"
@@ -53,14 +54,50 @@ var (
5354 errInvalidRefreshToken = & u2m.InvalidRefreshTokenError {}
5455)
5556
57+ // mockPersistentAuthFactory returns a PersistentAuthFactory that returns ts.
58+ func mockPersistentAuthFactory (ts oauth2.TokenSource ) PersistentAuthFactory {
59+ return func (ctx context.Context , opts ... u2m.PersistentAuthOption ) (oauth2.TokenSource , error ) {
60+ return ts , nil
61+ }
62+ }
63+
64+ // capturingPersistentAuthFactory returns a PersistentAuthFactory that applies
65+ // options to a real PersistentAuth and calls onCapture, allowing tests to spy
66+ // on the options passed. It returns ts for token operations.
67+ func capturingPersistentAuthFactory (ts oauth2.TokenSource , onCapture func (* u2m.PersistentAuth )) PersistentAuthFactory {
68+ return func (ctx context.Context , opts ... u2m.PersistentAuthOption ) (oauth2.TokenSource , error ) {
69+ pa , err := u2m .NewPersistentAuth (ctx , opts ... )
70+ if err != nil {
71+ return nil , err
72+ }
73+ if onCapture != nil {
74+ onCapture (pa )
75+ }
76+ return ts , nil
77+ }
78+ }
79+
80+ // equalStringSlices compares two string slices for equality.
81+ func equalStringSlices (a , b []string ) bool {
82+ if len (a ) != len (b ) {
83+ return false
84+ }
85+ for i := range a {
86+ if a [i ] != b [i ] {
87+ return false
88+ }
89+ }
90+ return true
91+ }
92+
5693func TestU2MCredentials_Configure (t * testing.T ) {
5794 testCases := []struct {
58- desc string
59- cfg * Config
60- testTokenSource * testTokenSource
61- wantConfigErr string // error message from Configure()
62- wantHeaderErr string // error message from SetHeaders()
63- wantAuthHeader string // expected Authorization header
95+ desc string
96+ cfg * Config
97+ tokenSource * testTokenSource
98+ wantConfigErr string // error message from Configure()
99+ wantHeaderErr string // error message from SetHeaders()
100+ wantAuthHeader string // expected Authorization header
64101 }{
65102 {
66103 desc : "missing host returns error" ,
@@ -74,7 +111,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
74111 cfg : & Config {
75112 Host : "https://workspace.cloud.databricks.com" ,
76113 },
77- testTokenSource : & testTokenSource {
114+ tokenSource : & testTokenSource {
78115 token : testValidToken ,
79116 },
80117 wantAuthHeader : "Bearer valid-access-token" ,
@@ -85,7 +122,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
85122 Host : "https://accounts.cloud.databricks.com" ,
86123 AccountID : "abc-123" ,
87124 },
88- testTokenSource : & testTokenSource {
125+ tokenSource : & testTokenSource {
89126 token : testValidToken ,
90127 },
91128 wantAuthHeader : "Bearer valid-access-token" ,
@@ -95,7 +132,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
95132 cfg : & Config {
96133 Host : "https://workspace.cloud.databricks.com" ,
97134 },
98- testTokenSource : & testTokenSource {
135+ tokenSource : & testTokenSource {
99136 token : testExpiredToken ,
100137 },
101138 wantAuthHeader : "Bearer expired-access-token" ,
@@ -105,7 +142,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
105142 cfg : & Config {
106143 Host : "https://workspace.cloud.databricks.com" ,
107144 },
108- testTokenSource : & testTokenSource {
145+ tokenSource : & testTokenSource {
109146 err : errNetwork ,
110147 },
111148 wantHeaderErr : "network timeout" ,
@@ -115,7 +152,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
115152 cfg : & Config {
116153 Host : "https://workspace.cloud.databricks.com" ,
117154 },
118- testTokenSource : & testTokenSource {
155+ tokenSource : & testTokenSource {
119156 err : errAuthentication ,
120157 },
121158 wantHeaderErr : "authentication failed" ,
@@ -127,7 +164,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
127164 Profile : "my-workspace" ,
128165 resolved : true ,
129166 },
130- testTokenSource : & testTokenSource {
167+ tokenSource : & testTokenSource {
131168 err : errInvalidRefreshToken ,
132169 },
133170 wantHeaderErr : "databricks auth login --profile my-workspace" ,
@@ -138,7 +175,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
138175 Host : "https://workspace.cloud.databricks.com" ,
139176 resolved : true ,
140177 },
141- testTokenSource : & testTokenSource {
178+ tokenSource : & testTokenSource {
142179 err : errInvalidRefreshToken ,
143180 },
144181 wantHeaderErr : "databricks auth login --host https://workspace.cloud.databricks.com" ,
@@ -151,7 +188,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
151188 Profile : "prod-account" ,
152189 resolved : true ,
153190 },
154- testTokenSource : & testTokenSource {
191+ tokenSource : & testTokenSource {
155192 err : errInvalidRefreshToken ,
156193 },
157194 wantHeaderErr : "databricks auth login --profile prod-account" ,
@@ -163,7 +200,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
163200 AccountID : "abc-123" ,
164201 resolved : true ,
165202 },
166- testTokenSource : & testTokenSource {
203+ tokenSource : & testTokenSource {
167204 err : errInvalidRefreshToken ,
168205 },
169206 wantHeaderErr : "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc-123" ,
@@ -175,7 +212,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
175212 Profile : "test" ,
176213 resolved : true ,
177214 },
178- testTokenSource : & testTokenSource {
215+ tokenSource : & testTokenSource {
179216 err : fmt .Errorf ("oauth2: %w" , errInvalidRefreshToken ),
180217 },
181218 wantHeaderErr : "databricks auth login --profile test" ,
@@ -187,7 +224,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
187224 AccountID : "abc-456" ,
188225 resolved : true ,
189226 },
190- testTokenSource : & testTokenSource {
227+ tokenSource : & testTokenSource {
191228 err : errInvalidRefreshToken ,
192229 },
193230 wantHeaderErr : "databricks auth login --host https://accounts.azure.databricks.net --account-id abc-456" ,
@@ -197,7 +234,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
197234 for _ , tc := range testCases {
198235 t .Run (tc .desc , func (t * testing.T ) {
199236 ctx := context .Background ()
200- u := u2mCredentials {testTokenSource : tc .testTokenSource }
237+ u := u2mCredentials {
238+ newPersistentAuth : mockPersistentAuthFactory (tc .tokenSource ),
239+ }
201240
202241 cp , gotConfigErr := u .Configure (ctx , tc .cfg )
203242
@@ -238,7 +277,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
238277func TestU2MCredentials_Configure_TokenCaching (t * testing.T ) {
239278 ts := & testTokenSource {token : testValidToken }
240279
241- u := u2mCredentials {testTokenSource : ts }
280+ u := u2mCredentials {
281+ newPersistentAuth : mockPersistentAuthFactory (ts ),
282+ }
242283 cfg := & Config {
243284 Host : "https://workspace.cloud.databricks.com" ,
244285 }
@@ -261,3 +302,54 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
261302 t .Errorf ("token source call count = %d, want 1 (should use cache)" , ts .counts )
262303 }
263304}
305+
306+ func TestU2MCredentials_Configure_DefaultScopes (t * testing.T ) {
307+ ts := & testTokenSource {token : testValidToken }
308+ var capturedScopes []string
309+
310+ u := u2mCredentials {
311+ newPersistentAuth : capturingPersistentAuthFactory (ts , func (pa * u2m.PersistentAuth ) {
312+ capturedScopes = pa .GetScopes ()
313+ }),
314+ }
315+ cfg := & Config {
316+ Host : "https://workspace.cloud.databricks.com" ,
317+ }
318+
319+ _ , err := u .Configure (context .Background (), cfg )
320+ if err != nil {
321+ t .Fatalf ("Configure() error = %v" , err )
322+ }
323+
324+ expectedScopes := []string {"all-apis" }
325+ if ! equalStringSlices (capturedScopes , expectedScopes ) {
326+ t .Errorf ("scopes = %v, want %v" , capturedScopes , expectedScopes )
327+ }
328+ }
329+
330+ func TestU2MCredentials_Configure_CustomScopes (t * testing.T ) {
331+ ts := & testTokenSource {token : testValidToken }
332+ var capturedScopes []string
333+
334+ u := u2mCredentials {
335+ newPersistentAuth : capturingPersistentAuthFactory (ts , func (pa * u2m.PersistentAuth ) {
336+ capturedScopes = pa .GetScopes ()
337+ }),
338+ }
339+ cfg := & Config {
340+ Host : "https://workspace.cloud.databricks.com" ,
341+ Scopes : []string {"sql" , "clusters" },
342+ }
343+
344+ _ , err := u .Configure (context .Background (), cfg )
345+ if err != nil {
346+ t .Fatalf ("Configure() error = %v" , err )
347+ }
348+
349+ // Scopes are sorted during config resolution.
350+ expectedScopes := []string {"clusters" , "sql" }
351+ sort .Strings (capturedScopes )
352+ if ! equalStringSlices (capturedScopes , expectedScopes ) {
353+ t .Errorf ("scopes = %v, want %v" , capturedScopes , expectedScopes )
354+ }
355+ }
0 commit comments