Skip to content

Commit dab64d8

Browse files
improve u2m flow testing by refactoring
1 parent c3c3c71 commit dab64d8

File tree

4 files changed

+147
-78
lines changed

4 files changed

+147
-78
lines changed

config/auth_u2m.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@ import (
1212
"golang.org/x/oauth2"
1313
)
1414

15+
// PersistentAuthFactory is a function that creates a token source for U2M
16+
// authentication. It can be replaced in tests to spy on the options passed.
17+
type PersistentAuthFactory func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error)
18+
1519
// u2mCredentials is a credentials strategy that uses the U2M OAuth flow to
1620
// authenticate with Databricks. It loads a token from the token cache for the
1721
// given workspace or account, refreshing it using the associated refresh token
1822
// if needed.
1923
type u2mCredentials struct {
20-
testTokenSource oauth2.TokenSource // replace u2m token source
24+
// newPersistentAuth is the factory function to create a PersistentAuth.
25+
// If nil, the default u2m.NewPersistentAuth is used.
26+
newPersistentAuth PersistentAuthFactory
2127
}
2228

2329
// Name implements CredentialsStrategy.
@@ -38,20 +44,21 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
3844
return nil, err
3945
}
4046

41-
var ts oauth2.TokenSource
42-
if u.testTokenSource != nil {
43-
ts = u.testTokenSource
44-
} else {
45-
ts, err = u2m.NewPersistentAuth(ctx,
46-
u2m.WithOAuthArgument(arg),
47-
u2m.WithPort(cfg.OAuthCallbackPort),
48-
u2m.WithScopes(cfg.GetScopes()),
49-
u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken),
50-
)
51-
if err != nil {
52-
return nil, err
47+
factory := u.newPersistentAuth
48+
if factory == nil {
49+
factory = func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
50+
return u2m.NewPersistentAuth(ctx, opts...)
5351
}
5452
}
53+
ts, err := factory(ctx,
54+
u2m.WithOAuthArgument(arg),
55+
u2m.WithPort(cfg.OAuthCallbackPort),
56+
u2m.WithScopes(cfg.GetScopes()),
57+
u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken),
58+
)
59+
if err != nil {
60+
return nil, err
61+
}
5562

5663
// TODO: Having to handle the CLI error here is not ideal as it couples the
5764
// SDK logic with the CLI's. Remove this wrapping logic as soon as the CLI

config/auth_u2m_test.go

Lines changed: 111 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"net/http"
8+
"sort"
89
"strings"
910
"testing"
1011
"time"
@@ -53,14 +54,50 @@ var (
5354
errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{}
5455
)
5556

57+
// mockPersistentAuthFactory returns a PersistentAuthFactory that returns ts.
58+
func mockPersistentAuthFactory(ts oauth2.TokenSource) PersistentAuthFactory {
59+
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
60+
return ts, nil
61+
}
62+
}
63+
64+
// capturingPersistentAuthFactory returns a PersistentAuthFactory that applies
65+
// options to a real PersistentAuth and calls onCapture, allowing tests to spy
66+
// on the options passed. It returns ts for token operations.
67+
func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.PersistentAuth)) PersistentAuthFactory {
68+
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
69+
pa, err := u2m.NewPersistentAuth(ctx, opts...)
70+
if err != nil {
71+
return nil, err
72+
}
73+
if onCapture != nil {
74+
onCapture(pa)
75+
}
76+
return ts, nil
77+
}
78+
}
79+
80+
// equalStringSlices compares two string slices for equality.
81+
func equalStringSlices(a, b []string) bool {
82+
if len(a) != len(b) {
83+
return false
84+
}
85+
for i := range a {
86+
if a[i] != b[i] {
87+
return false
88+
}
89+
}
90+
return true
91+
}
92+
5693
func TestU2MCredentials_Configure(t *testing.T) {
5794
testCases := []struct {
58-
desc string
59-
cfg *Config
60-
testTokenSource *testTokenSource
61-
wantConfigErr string // error message from Configure()
62-
wantHeaderErr string // error message from SetHeaders()
63-
wantAuthHeader string // expected Authorization header
95+
desc string
96+
cfg *Config
97+
tokenSource *testTokenSource
98+
wantConfigErr string // error message from Configure()
99+
wantHeaderErr string // error message from SetHeaders()
100+
wantAuthHeader string // expected Authorization header
64101
}{
65102
{
66103
desc: "missing host returns error",
@@ -74,7 +111,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
74111
cfg: &Config{
75112
Host: "https://workspace.cloud.databricks.com",
76113
},
77-
testTokenSource: &testTokenSource{
114+
tokenSource: &testTokenSource{
78115
token: testValidToken,
79116
},
80117
wantAuthHeader: "Bearer valid-access-token",
@@ -85,7 +122,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
85122
Host: "https://accounts.cloud.databricks.com",
86123
AccountID: "abc-123",
87124
},
88-
testTokenSource: &testTokenSource{
125+
tokenSource: &testTokenSource{
89126
token: testValidToken,
90127
},
91128
wantAuthHeader: "Bearer valid-access-token",
@@ -95,7 +132,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
95132
cfg: &Config{
96133
Host: "https://workspace.cloud.databricks.com",
97134
},
98-
testTokenSource: &testTokenSource{
135+
tokenSource: &testTokenSource{
99136
token: testExpiredToken,
100137
},
101138
wantAuthHeader: "Bearer expired-access-token",
@@ -105,7 +142,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
105142
cfg: &Config{
106143
Host: "https://workspace.cloud.databricks.com",
107144
},
108-
testTokenSource: &testTokenSource{
145+
tokenSource: &testTokenSource{
109146
err: errNetwork,
110147
},
111148
wantHeaderErr: "network timeout",
@@ -115,7 +152,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
115152
cfg: &Config{
116153
Host: "https://workspace.cloud.databricks.com",
117154
},
118-
testTokenSource: &testTokenSource{
155+
tokenSource: &testTokenSource{
119156
err: errAuthentication,
120157
},
121158
wantHeaderErr: "authentication failed",
@@ -127,7 +164,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
127164
Profile: "my-workspace",
128165
resolved: true,
129166
},
130-
testTokenSource: &testTokenSource{
167+
tokenSource: &testTokenSource{
131168
err: errInvalidRefreshToken,
132169
},
133170
wantHeaderErr: "databricks auth login --profile my-workspace",
@@ -138,7 +175,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
138175
Host: "https://workspace.cloud.databricks.com",
139176
resolved: true,
140177
},
141-
testTokenSource: &testTokenSource{
178+
tokenSource: &testTokenSource{
142179
err: errInvalidRefreshToken,
143180
},
144181
wantHeaderErr: "databricks auth login --host https://workspace.cloud.databricks.com",
@@ -151,7 +188,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
151188
Profile: "prod-account",
152189
resolved: true,
153190
},
154-
testTokenSource: &testTokenSource{
191+
tokenSource: &testTokenSource{
155192
err: errInvalidRefreshToken,
156193
},
157194
wantHeaderErr: "databricks auth login --profile prod-account",
@@ -163,7 +200,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
163200
AccountID: "abc-123",
164201
resolved: true,
165202
},
166-
testTokenSource: &testTokenSource{
203+
tokenSource: &testTokenSource{
167204
err: errInvalidRefreshToken,
168205
},
169206
wantHeaderErr: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc-123",
@@ -175,7 +212,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
175212
Profile: "test",
176213
resolved: true,
177214
},
178-
testTokenSource: &testTokenSource{
215+
tokenSource: &testTokenSource{
179216
err: fmt.Errorf("oauth2: %w", errInvalidRefreshToken),
180217
},
181218
wantHeaderErr: "databricks auth login --profile test",
@@ -187,7 +224,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
187224
AccountID: "abc-456",
188225
resolved: true,
189226
},
190-
testTokenSource: &testTokenSource{
227+
tokenSource: &testTokenSource{
191228
err: errInvalidRefreshToken,
192229
},
193230
wantHeaderErr: "databricks auth login --host https://accounts.azure.databricks.net --account-id abc-456",
@@ -197,7 +234,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
197234
for _, tc := range testCases {
198235
t.Run(tc.desc, func(t *testing.T) {
199236
ctx := context.Background()
200-
u := u2mCredentials{testTokenSource: tc.testTokenSource}
237+
u := u2mCredentials{
238+
newPersistentAuth: mockPersistentAuthFactory(tc.tokenSource),
239+
}
201240

202241
cp, gotConfigErr := u.Configure(ctx, tc.cfg)
203242

@@ -238,7 +277,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
238277
func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
239278
ts := &testTokenSource{token: testValidToken}
240279

241-
u := u2mCredentials{testTokenSource: ts}
280+
u := u2mCredentials{
281+
newPersistentAuth: mockPersistentAuthFactory(ts),
282+
}
242283
cfg := &Config{
243284
Host: "https://workspace.cloud.databricks.com",
244285
}
@@ -261,3 +302,54 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
261302
t.Errorf("token source call count = %d, want 1 (should use cache)", ts.counts)
262303
}
263304
}
305+
306+
func TestU2MCredentials_Configure_DefaultScopes(t *testing.T) {
307+
ts := &testTokenSource{token: testValidToken}
308+
var capturedScopes []string
309+
310+
u := u2mCredentials{
311+
newPersistentAuth: capturingPersistentAuthFactory(ts, func(pa *u2m.PersistentAuth) {
312+
capturedScopes = pa.GetScopes()
313+
}),
314+
}
315+
cfg := &Config{
316+
Host: "https://workspace.cloud.databricks.com",
317+
}
318+
319+
_, err := u.Configure(context.Background(), cfg)
320+
if err != nil {
321+
t.Fatalf("Configure() error = %v", err)
322+
}
323+
324+
expectedScopes := []string{"all-apis"}
325+
if !equalStringSlices(capturedScopes, expectedScopes) {
326+
t.Errorf("scopes = %v, want %v", capturedScopes, expectedScopes)
327+
}
328+
}
329+
330+
func TestU2MCredentials_Configure_CustomScopes(t *testing.T) {
331+
ts := &testTokenSource{token: testValidToken}
332+
var capturedScopes []string
333+
334+
u := u2mCredentials{
335+
newPersistentAuth: capturingPersistentAuthFactory(ts, func(pa *u2m.PersistentAuth) {
336+
capturedScopes = pa.GetScopes()
337+
}),
338+
}
339+
cfg := &Config{
340+
Host: "https://workspace.cloud.databricks.com",
341+
Scopes: []string{"sql", "clusters"},
342+
}
343+
344+
_, err := u.Configure(context.Background(), cfg)
345+
if err != nil {
346+
t.Fatalf("Configure() error = %v", err)
347+
}
348+
349+
// Scopes are sorted during config resolution.
350+
expectedScopes := []string{"clusters", "sql"}
351+
sort.Strings(capturedScopes)
352+
if !equalStringSlices(capturedScopes, expectedScopes) {
353+
t.Errorf("scopes = %v, want %v", capturedScopes, expectedScopes)
354+
}
355+
}

credentials/u2m/persistent_auth.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ func WithDisableOfflineAccess(disable bool) PersistentAuthOption {
158158
}
159159
}
160160

161+
// GetScopes returns the OAuth scopes configured for this PersistentAuth.
162+
func (a *PersistentAuth) GetScopes() []string {
163+
return a.scopes
164+
}
165+
161166
// NewPersistentAuth creates a new PersistentAuth with the provided options.
162167
func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) {
163168
p := &PersistentAuth{}

0 commit comments

Comments
 (0)