Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions config/auth_u2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ import (
"golang.org/x/oauth2"
)

// persistentAuthFactory is a function that creates a token source for U2M
// authentication. It can be replaced in tests to spy on the options passed.
type persistentAuthFactory func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error)

// u2mCredentials is a credentials strategy that uses the U2M OAuth flow to
// authenticate with Databricks. It loads a token from the token cache for the
// given workspace or account, refreshing it using the associated refresh token
// if needed.
type u2mCredentials struct {
testTokenSource oauth2.TokenSource // replace u2m token source
// newPersistentAuth is the factory function to create a PersistentAuth.
// If nil, the default u2m.NewPersistentAuth is used.
newPersistentAuth persistentAuthFactory
}

// Name implements CredentialsStrategy.
Expand All @@ -38,14 +44,22 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
return nil, err
}

var ts oauth2.TokenSource
if u.testTokenSource != nil {
ts = u.testTokenSource
} else {
ts, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg), u2m.WithPort(cfg.OAuthCallbackPort))
if err != nil {
return nil, err
var factory persistentAuthFactory
if u.newPersistentAuth == nil {
factory = func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
return u2m.NewPersistentAuth(ctx, opts...)
}
} else {
factory = u.newPersistentAuth
}
ts, err := factory(ctx,
u2m.WithOAuthArgument(arg),
u2m.WithPort(cfg.OAuthCallbackPort),
u2m.WithScopes(cfg.GetScopes()),
u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken),
)
if err != nil {
return nil, err
}

// TODO: Having to handle the CLI error here is not ideal as it couples the
Expand Down
91 changes: 89 additions & 2 deletions config/auth_u2m_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/oauth2"
)

Expand Down Expand Up @@ -53,6 +55,14 @@ var (
errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{}
)

// mockPersistentAuthFactory creates a test factory for bypassing real auth setup.
// Use this when tests only need to control token behavior without caring about auth configuration.
func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory {
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
return ts, nil
}
}

func TestU2MCredentials_Configure(t *testing.T) {
testCases := []struct {
desc string
Expand Down Expand Up @@ -197,7 +207,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background()
u := u2mCredentials{testTokenSource: tc.testTokenSource}
u := u2mCredentials{
newPersistentAuth: mockPersistentAuthFactory(tc.testTokenSource),
}

cp, gotConfigErr := u.Configure(ctx, tc.cfg)

Expand Down Expand Up @@ -238,7 +250,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
ts := &testTokenSource{token: testValidToken}

u := u2mCredentials{testTokenSource: ts}
u := u2mCredentials{
newPersistentAuth: mockPersistentAuthFactory(ts),
}
cfg := &Config{
Host: "https://workspace.cloud.databricks.com",
}
Expand All @@ -261,3 +275,76 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
t.Errorf("token source call count = %d, want 1 (should use cache)", ts.counts)
}
}

func TestU2MCredentials_Configure_Scopes(t *testing.T) {
const workspaceHost = "https://workspace.cloud.databricks.com"
testCases := []struct {
name string
scopes []string
want []string
}{
{
name: "nil scopes uses default",
scopes: nil,
want: []string{"all-apis"},
},
{
name: "empty scopes uses default",
scopes: []string{},
want: []string{"all-apis"},
},
{
name: "multiple scopes are sorted",
scopes: []string{"clusters", "jobs", "sql:read"},
want: []string{"clusters", "jobs", "sql:read"},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := &testTokenSource{token: testValidToken}
var capturedPA *u2m.PersistentAuth

u := u2mCredentials{
newPersistentAuth: func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
pa, err := u2m.NewPersistentAuth(ctx, opts...)
if err != nil {
return nil, err
}
capturedPA = pa
return ts, nil
},
}
cfg := &Config{
Host: workspaceHost,
Scopes: tc.scopes,
}

_, err := u.Configure(context.Background(), cfg)
if err != nil {
t.Fatalf("Configure() error = %v", err)
}

arg, err := u2m.NewBasicWorkspaceOAuthArgument(workspaceHost)
if err != nil {
t.Fatalf("NewBasicWorkspaceOAuthArgument() error = %v", err)
}
wantPA, err := u2m.NewPersistentAuth(context.Background(),
u2m.WithOAuthArgument(arg),
u2m.WithScopes(tc.want),
)
if err != nil {
t.Fatalf("NewPersistentAuth() error = %v", err)
}

if diff := cmp.Diff(wantPA, capturedPA,
cmp.AllowUnexported(u2m.PersistentAuth{}, u2m.BasicWorkspaceOAuthArgument{}),
cmpopts.IgnoreFields(u2m.PersistentAuth{},
"cache", "client", "endpointSupplier", "browser",
"ln", "ctx", "redirectAddr", "port", "netListen", "disableOfflineAccess"),
); diff != "" {
t.Errorf("PersistentAuth mismatch (-want +got):\n%s", diff)
}
})
}
}
34 changes: 31 additions & 3 deletions credentials/u2m/persistent_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ type PersistentAuth struct {
// netListen is an optional function to listen on a TCP address. If not set,
// it will use net.Listen by default. This is useful for testing.
netListen func(network, address string) (net.Listener, error)

// scopes is the list of OAuth scopes to request.
scopes []string

// disableOfflineAccess controls whether offline_access scope is requested.
// When true, offline_access will NOT be automatically added to scopes,
// meaning the token will not include a refresh token.
disableOfflineAccess bool
}

type PersistentAuthOption func(*PersistentAuth)
Expand Down Expand Up @@ -135,6 +143,20 @@ func WithPort(port int) PersistentAuthOption {
}
}

// WithScopes sets the OAuth scopes for the PersistentAuth.
func WithScopes(scopes []string) PersistentAuthOption {
return func(a *PersistentAuth) {
a.scopes = scopes
}
}

// WithDisableOfflineAccess controls whether offline_access scope is requested.
func WithDisableOfflineAccess(disable bool) PersistentAuthOption {
return func(a *PersistentAuth) {
a.disableOfflineAccess = disable
}
}

// NewPersistentAuth creates a new PersistentAuth with the provided options.
func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) {
p := &PersistentAuth{}
Expand Down Expand Up @@ -368,10 +390,16 @@ func (a *PersistentAuth) validateArg() error {

// oauth2Config returns the OAuth2 configuration for the given OAuthArgument.
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
scopes := []string{
"offline_access", // ensures OAuth token includes refresh token
"all-apis", // ensures OAuth token has access to all control-plane APIs
// Default to "all-apis" for backwards compatibility with direct users of PersistentAuth
// i.e. people implementing their own U2M authentication.
scopes := a.scopes
if len(scopes) == 0 {
scopes = []string{"all-apis"}
}
if !a.disableOfflineAccess {
scopes = append([]string{"offline_access"}, scopes...)
}

var endpoints *OAuthAuthorizationServer
var err error
switch argg := a.oAuthArgument.(type) {
Expand Down
153 changes: 153 additions & 0 deletions credentials/u2m/persistent_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,156 @@ func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) {
t.Fatalf("pa.startListener(): want error %v, got %v", testError, gotErr)
}
}

// TestU2M_ScopesAndOfflineAccess verifies that OAuth scopes are correctly configured
// and sent during the authorization flow, and that the disableOfflineAccess flag
// correctly controls whether offline_access is added to the scope.
func TestU2M_ScopesAndOfflineAccess(t *testing.T) {
const (
testWorkspaceHost = "https://workspace.cloud.databricks.com"
testTokenEndpoint = "/oidc/v1/token"
testCallbackURL = "http://localhost:8020"
)

tests := []struct {
name string
scopes []string
disableOffline bool
want string
}{
{
name: "nil scopes uses default with offline_access",
scopes: nil,
disableOffline: false,
want: "offline_access all-apis",
},
{
name: "empty scopes uses default with offline_access",
scopes: []string{},
disableOffline: false,
want: "offline_access all-apis",
},
{
name: "single scope with offline_access",
scopes: []string{"dashboards"},
disableOffline: false,
want: "offline_access dashboards",
},
{
name: "multiple scopes with offline_access",
scopes: []string{"files", "jobs", "mlflow:read"},
disableOffline: false,
want: "offline_access files jobs mlflow:read",
},
{
name: "disable offline_access",
scopes: []string{"files", "jobs"},
disableOffline: true,
want: "files jobs",
},
{
name: "nil scopes with disable offline_access",
scopes: nil,
disableOffline: true,
want: "all-apis",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()

var scopeReceived, stateReceived string
browserCalled := make(chan struct{})
defer close(browserCalled)
browser := func(redirect string) error {
u, err := url.ParseRequestURI(redirect)
if err != nil {
return err
}
query := u.Query()
scopeReceived = query.Get("scope")
stateReceived = query.Get("state")
browserCalled <- struct{}{}
return nil
}

cache := &tokenCacheMock{
store: func(key string, tok *oauth2.Token) error {
return nil
},
}

arg, err := NewBasicWorkspaceOAuthArgument(testWorkspaceHost)
if err != nil {
t.Fatalf("NewBasicWorkspaceOAuthArgument(): want no error, got %v", err)
}

var tokenResponse string
if tt.disableOffline {
tokenResponse = `access_token=token`
} else {
tokenResponse = `access_token=token&refresh_token=refresh`
}

opts := []PersistentAuthOption{
WithTokenCache(cache),
WithBrowser(browser),
WithHttpClient(&http.Client{
Transport: fixtures.SliceTransport{
{
Method: "POST",
Resource: testTokenEndpoint,
Response: tokenResponse,
ResponseHeaders: map[string][]string{
"Content-Type": {"application/x-www-form-urlencoded"},
},
},
},
}),
WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}),
WithOAuthArgument(arg),
WithDisableOfflineAccess(tt.disableOffline),
WithScopes(tt.scopes),
}

p, err := NewPersistentAuth(ctx, opts...)
if err != nil {
t.Fatalf("NewPersistentAuth(): want no error, got %v", err)
}
defer p.Close()

errc := make(chan error)
defer close(errc)
go func() {
err := p.Challenge()
errc <- err
}()

select {
case <-browserCalled:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for browser to be called")
}

if scopeReceived != tt.want {
t.Errorf("scope: want %q, got %q", tt.want, scopeReceived)
}

resp, err := http.Get(fmt.Sprintf("%s?code=__CODE__&state=%s", testCallbackURL, stateReceived))
if err != nil {
t.Fatalf("http.Get(): want no error, got %v", err)
}
defer resp.Body.Close()

select {
case err = <-errc:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for Challenge() to complete")
}
if err != nil {
t.Fatalf("p.Challenge(): want no error, got %v", err)
}
})
}
}
Loading