Skip to content

Commit a4e263c

Browse files
custom scopes support in u2m
1 parent 2afaab5 commit a4e263c

File tree

4 files changed

+326
-30
lines changed

4 files changed

+326
-30
lines changed

config/auth_u2m.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@ import (
1212
"golang.org/x/oauth2"
1313
)
1414

15+
// persistentAuthFactory is a function that creates a token source for U2M
16+
// authentication. It can be replaced in tests to spy on the options passed.
17+
type persistentAuthFactory func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error)
18+
1519
// u2mCredentials is a credentials strategy that uses the U2M OAuth flow to
1620
// authenticate with Databricks. It loads a token from the token cache for the
1721
// given workspace or account, refreshing it using the associated refresh token
1822
// if needed.
1923
type u2mCredentials struct {
20-
testTokenSource oauth2.TokenSource // replace u2m token source
24+
// newPersistentAuth is the factory function to create a PersistentAuth.
25+
// If nil, the default u2m.NewPersistentAuth is used.
26+
newPersistentAuth persistentAuthFactory
2127
}
2228

2329
// Name implements CredentialsStrategy.
@@ -38,14 +44,22 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
3844
return nil, err
3945
}
4046

41-
var ts oauth2.TokenSource
42-
if u.testTokenSource != nil {
43-
ts = u.testTokenSource
44-
} else {
45-
ts, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg), u2m.WithPort(cfg.OAuthCallbackPort))
46-
if err != nil {
47-
return nil, err
47+
var factory persistentAuthFactory
48+
if u.newPersistentAuth == nil {
49+
factory = func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
50+
return u2m.NewPersistentAuth(ctx, opts...)
4851
}
52+
} else {
53+
factory = u.newPersistentAuth
54+
}
55+
ts, err := factory(ctx,
56+
u2m.WithOAuthArgument(arg),
57+
u2m.WithPort(cfg.OAuthCallbackPort),
58+
u2m.WithScopes(cfg.GetScopes()),
59+
u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken),
60+
)
61+
if err != nil {
62+
return nil, err
4963
}
5064

5165
// TODO: Having to handle the CLI error here is not ideal as it couples the

config/auth_u2m_test.go

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ import (
55
"errors"
66
"fmt"
77
"net/http"
8+
"sort"
89
"strings"
910
"testing"
1011
"time"
1112

1213
"github.com/databricks/databricks-sdk-go/config/credentials"
1314
"github.com/databricks/databricks-sdk-go/credentials/u2m"
15+
"github.com/google/go-cmp/cmp"
1416
"golang.org/x/oauth2"
1517
)
1618

@@ -53,14 +55,37 @@ var (
5355
errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{}
5456
)
5557

58+
// mockPersistentAuthFactory returns a persistentAuthFactory that returns ts.
59+
func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory {
60+
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
61+
return ts, nil
62+
}
63+
}
64+
65+
// capturingPersistentAuthFactory returns a persistentAuthFactory that applies
66+
// options to a real PersistentAuth and calls onCapture, allowing tests to spy
67+
// on the options passed. It returns ts for token operations.
68+
func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.PersistentAuth)) persistentAuthFactory {
69+
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
70+
pa, err := u2m.NewPersistentAuth(ctx, opts...)
71+
if err != nil {
72+
return nil, err
73+
}
74+
if onCapture != nil {
75+
onCapture(pa)
76+
}
77+
return ts, nil
78+
}
79+
}
80+
5681
func TestU2MCredentials_Configure(t *testing.T) {
5782
testCases := []struct {
58-
desc string
59-
cfg *Config
60-
testTokenSource *testTokenSource
61-
wantConfigErr string // error message from Configure()
62-
wantHeaderErr string // error message from SetHeaders()
63-
wantAuthHeader string // expected Authorization header
83+
desc string
84+
cfg *Config
85+
tokenSource *testTokenSource
86+
wantConfigErr string // error message from Configure()
87+
wantHeaderErr string // error message from SetHeaders()
88+
wantAuthHeader string // expected Authorization header
6489
}{
6590
{
6691
desc: "missing host returns error",
@@ -74,7 +99,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
7499
cfg: &Config{
75100
Host: "https://workspace.cloud.databricks.com",
76101
},
77-
testTokenSource: &testTokenSource{
102+
tokenSource: &testTokenSource{
78103
token: testValidToken,
79104
},
80105
wantAuthHeader: "Bearer valid-access-token",
@@ -85,7 +110,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
85110
Host: "https://accounts.cloud.databricks.com",
86111
AccountID: "abc-123",
87112
},
88-
testTokenSource: &testTokenSource{
113+
tokenSource: &testTokenSource{
89114
token: testValidToken,
90115
},
91116
wantAuthHeader: "Bearer valid-access-token",
@@ -95,7 +120,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
95120
cfg: &Config{
96121
Host: "https://workspace.cloud.databricks.com",
97122
},
98-
testTokenSource: &testTokenSource{
123+
tokenSource: &testTokenSource{
99124
token: testExpiredToken,
100125
},
101126
wantAuthHeader: "Bearer expired-access-token",
@@ -105,7 +130,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
105130
cfg: &Config{
106131
Host: "https://workspace.cloud.databricks.com",
107132
},
108-
testTokenSource: &testTokenSource{
133+
tokenSource: &testTokenSource{
109134
err: errNetwork,
110135
},
111136
wantHeaderErr: "network timeout",
@@ -115,7 +140,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
115140
cfg: &Config{
116141
Host: "https://workspace.cloud.databricks.com",
117142
},
118-
testTokenSource: &testTokenSource{
143+
tokenSource: &testTokenSource{
119144
err: errAuthentication,
120145
},
121146
wantHeaderErr: "authentication failed",
@@ -127,7 +152,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
127152
Profile: "my-workspace",
128153
resolved: true,
129154
},
130-
testTokenSource: &testTokenSource{
155+
tokenSource: &testTokenSource{
131156
err: errInvalidRefreshToken,
132157
},
133158
wantHeaderErr: "databricks auth login --profile my-workspace",
@@ -138,7 +163,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
138163
Host: "https://workspace.cloud.databricks.com",
139164
resolved: true,
140165
},
141-
testTokenSource: &testTokenSource{
166+
tokenSource: &testTokenSource{
142167
err: errInvalidRefreshToken,
143168
},
144169
wantHeaderErr: "databricks auth login --host https://workspace.cloud.databricks.com",
@@ -151,7 +176,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
151176
Profile: "prod-account",
152177
resolved: true,
153178
},
154-
testTokenSource: &testTokenSource{
179+
tokenSource: &testTokenSource{
155180
err: errInvalidRefreshToken,
156181
},
157182
wantHeaderErr: "databricks auth login --profile prod-account",
@@ -163,7 +188,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
163188
AccountID: "abc-123",
164189
resolved: true,
165190
},
166-
testTokenSource: &testTokenSource{
191+
tokenSource: &testTokenSource{
167192
err: errInvalidRefreshToken,
168193
},
169194
wantHeaderErr: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc-123",
@@ -175,7 +200,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
175200
Profile: "test",
176201
resolved: true,
177202
},
178-
testTokenSource: &testTokenSource{
203+
tokenSource: &testTokenSource{
179204
err: fmt.Errorf("oauth2: %w", errInvalidRefreshToken),
180205
},
181206
wantHeaderErr: "databricks auth login --profile test",
@@ -187,7 +212,7 @@ func TestU2MCredentials_Configure(t *testing.T) {
187212
AccountID: "abc-456",
188213
resolved: true,
189214
},
190-
testTokenSource: &testTokenSource{
215+
tokenSource: &testTokenSource{
191216
err: errInvalidRefreshToken,
192217
},
193218
wantHeaderErr: "databricks auth login --host https://accounts.azure.databricks.net --account-id abc-456",
@@ -197,7 +222,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
197222
for _, tc := range testCases {
198223
t.Run(tc.desc, func(t *testing.T) {
199224
ctx := context.Background()
200-
u := u2mCredentials{testTokenSource: tc.testTokenSource}
225+
u := u2mCredentials{
226+
newPersistentAuth: mockPersistentAuthFactory(tc.tokenSource),
227+
}
201228

202229
cp, gotConfigErr := u.Configure(ctx, tc.cfg)
203230

@@ -238,7 +265,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
238265
func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
239266
ts := &testTokenSource{token: testValidToken}
240267

241-
u := u2mCredentials{testTokenSource: ts}
268+
u := u2mCredentials{
269+
newPersistentAuth: mockPersistentAuthFactory(ts),
270+
}
242271
cfg := &Config{
243272
Host: "https://workspace.cloud.databricks.com",
244273
}
@@ -261,3 +290,54 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
261290
t.Errorf("token source call count = %d, want 1 (should use cache)", ts.counts)
262291
}
263292
}
293+
294+
func TestU2MCredentials_Configure_Scopes(t *testing.T) {
295+
testCases := []struct {
296+
desc string
297+
configScopes []string
298+
expectedScopes []string
299+
sortScopes bool // whether to sort captured scopes before comparison
300+
}{
301+
{
302+
desc: "default scopes when not specified",
303+
configScopes: nil,
304+
expectedScopes: []string{"all-apis"},
305+
sortScopes: false,
306+
},
307+
{
308+
desc: "custom scopes are passed through",
309+
configScopes: []string{"sql", "clusters"},
310+
expectedScopes: []string{"clusters", "sql"}, // sorted during config resolution
311+
sortScopes: true,
312+
},
313+
}
314+
315+
for _, tc := range testCases {
316+
t.Run(tc.desc, func(t *testing.T) {
317+
ts := &testTokenSource{token: testValidToken}
318+
var capturedScopes []string
319+
320+
u := u2mCredentials{
321+
newPersistentAuth: capturingPersistentAuthFactory(ts, func(pa *u2m.PersistentAuth) {
322+
capturedScopes = pa.GetScopes()
323+
}),
324+
}
325+
cfg := &Config{
326+
Host: "https://workspace.cloud.databricks.com",
327+
Scopes: tc.configScopes,
328+
}
329+
330+
_, err := u.Configure(context.Background(), cfg)
331+
if err != nil {
332+
t.Fatalf("Configure() error = %v", err)
333+
}
334+
335+
if tc.sortScopes {
336+
sort.Strings(capturedScopes)
337+
}
338+
if diff := cmp.Diff(tc.expectedScopes, capturedScopes); diff != "" {
339+
t.Errorf("scopes mismatch (-want +got):\n%s", diff)
340+
}
341+
})
342+
}
343+
}

credentials/u2m/persistent_auth.go

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ type PersistentAuth struct {
8888
// netListen is an optional function to listen on a TCP address. If not set,
8989
// it will use net.Listen by default. This is useful for testing.
9090
netListen func(network, address string) (net.Listener, error)
91+
92+
// scopes is the list of OAuth scopes to request.
93+
scopes []string
94+
95+
// disableOfflineAccess controls whether offline_access scope is requested.
96+
// When true, offline_access will NOT be automatically added to scopes,
97+
// meaning the token will not include a refresh token.
98+
disableOfflineAccess bool
9199
}
92100

93101
type PersistentAuthOption func(*PersistentAuth)
@@ -135,6 +143,26 @@ func WithPort(port int) PersistentAuthOption {
135143
}
136144
}
137145

146+
// WithScopes sets the OAuth scopes for the PersistentAuth.
147+
func WithScopes(scopes []string) PersistentAuthOption {
148+
return func(a *PersistentAuth) {
149+
a.scopes = scopes
150+
}
151+
}
152+
153+
// WithDisableOfflineAccess controls whether offline_access scope is requested.
154+
// When true, offline_access will NOT be automatically added to scopes.
155+
func WithDisableOfflineAccess(disable bool) PersistentAuthOption {
156+
return func(a *PersistentAuth) {
157+
a.disableOfflineAccess = disable
158+
}
159+
}
160+
161+
// GetScopes returns the OAuth scopes configured for this PersistentAuth.
162+
func (a *PersistentAuth) GetScopes() []string {
163+
return a.scopes
164+
}
165+
138166
// NewPersistentAuth creates a new PersistentAuth with the provided options.
139167
func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) {
140168
p := &PersistentAuth{}
@@ -368,10 +396,13 @@ func (a *PersistentAuth) validateArg() error {
368396

369397
// oauth2Config returns the OAuth2 configuration for the given OAuthArgument.
370398
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
371-
scopes := []string{
372-
"offline_access", // ensures OAuth token includes refresh token
373-
"all-apis", // ensures OAuth token has access to all control-plane APIs
399+
scopes := a.scopes
400+
if !a.disableOfflineAccess {
401+
// Use append to create a new slice with "offline_access" added,
402+
// avoiding mutation of the original a.scopes slice.
403+
scopes = append(append([]string{}, scopes...), "offline_access")
374404
}
405+
375406
var endpoints *OAuthAuthorizationServer
376407
var err error
377408
switch argg := a.oAuthArgument.(type) {

0 commit comments

Comments
 (0)