diff --git a/config/auth_default.go b/config/auth_default.go index 1b34a16fb..8ec3ce119 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -149,6 +149,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt if cfg.HostType() != WorkspaceHost { oidcConfig.AccountID = cfg.AccountID } + oidcConfig.SetScopes(cfg.GetScopes()) tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) return NewTokenSourceStrategy(name, tokenSource) } diff --git a/config/auth_default_test.go b/config/auth_default_test.go index fbcb67116..ede5dfb3a 100644 --- a/config/auth_default_test.go +++ b/config/auth_default_test.go @@ -2,8 +2,13 @@ package config import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" "strings" "testing" + + "github.com/databricks/databricks-sdk-go/credentials/u2m" ) func TestDefaultCredentials_Configure(t *testing.T) { @@ -47,3 +52,101 @@ func TestDefaultCredentials_Configure(t *testing.T) { }) } } + +func TestGithubOIDC_Scopes(t *testing.T) { + const oidcTokenPath = "/oidc/v1/token" + + tests := []struct { + name string + scopes []string + want string + }{ + { + name: "nil scopes uses default", + scopes: nil, + want: "all-apis", + }, + { + name: "empty scopes uses default", + scopes: []string{}, + want: "all-apis", + }, + { + name: "single scope", + scopes: []string{"clusters"}, + want: "clusters", + }, + { + name: "multiple scopes are sorted", + scopes: []string{"jobs", "clusters", "files:read"}, + want: "clusters files:read jobs", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock GitHub server for OIDC token requests. + githubServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"value": "github-id-token"}) + })) + defer githubServer.Close() + + // Mock Databricks server to verify the SDK passes the correct scopes. + var databricksServer *httptest.Server + databricksServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oidc/.well-known/oauth-authorization-server": + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://host.com/oidc/v1/authorize", + TokenEndpoint: databricksServer.URL + oidcTokenPath, + }) + + case oidcTokenPath: + if err := r.ParseForm(); err != nil { + t.Fatalf("Failed to parse form: %v", err) + } + // The scope assertion: verifies the SDK sends the correct scope parameter. + if got := r.Form.Get("scope"); got != tt.want { + t.Errorf("scope: got %q, want %q", got, tt.want) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "token_type": "Bearer", + "access_token": "databricks-access-token", + "expires_in": 3600, + }) + + default: + t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path) + http.Error(w, "Not found", http.StatusNotFound) + } + })) + defer databricksServer.Close() + + cfg := &Config{ + Host: databricksServer.URL, + ClientID: "test-client-id", + ActionsIDTokenRequestURL: githubServer.URL + "/github-token?version=1", + ActionsIDTokenRequestToken: "github-request-token", + TokenAudience: "databricks-test-audience", + AuthType: "github-oidc", + Scopes: tt.scopes, + } + + req, err := http.NewRequest("GET", databricksServer.URL+"/api/test", nil) + if err != nil { + t.Fatalf("http.NewRequest(): unexpected error: %v", err) + } + err = cfg.Authenticate(req) + if err != nil { + t.Fatalf("Authenticate(): unexpected error: %v", err) + } + wantAuthHeader := "Bearer databricks-access-token" + if got := req.Header.Get("Authorization"); got != wantAuthHeader { + t.Errorf("Authorization header: got %q, want %q", got, wantAuthHeader) + } + }) + } +} diff --git a/config/experimental/auth/oidc/tokensource.go b/config/experimental/auth/oidc/tokensource.go index 25eb1b6f4..5d5e1a114 100644 --- a/config/experimental/auth/oidc/tokensource.go +++ b/config/experimental/auth/oidc/tokensource.go @@ -39,6 +39,23 @@ type DatabricksOIDCTokenSourceConfig struct { // IDTokenSource returns the IDToken to be used for the token exchange. IDTokenSource IDTokenSource + + // scopes is the list of OAuth scopes to request. + scopes []string +} + +// GetScopes returns the OAuth scopes to request. If no scopes have been set, +// it returns the default scope "all-apis". +func (c *DatabricksOIDCTokenSourceConfig) GetScopes() []string { + if len(c.scopes) == 0 { + return []string{"all-apis"} + } + return c.scopes +} + +// SetScopes sets the OAuth scopes to request. +func (c *DatabricksOIDCTokenSourceConfig) SetScopes(scopes []string) { + c.scopes = scopes } // NewDatabricksOIDCTokenSource returns a new Databricks OIDC TokenSource. @@ -77,11 +94,14 @@ func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, e return nil, err } + // This nil check is to ensure backwards compatibility for users implementing their own + // OIDC token source. + scopes := w.cfg.GetScopes() c := &clientcredentials.Config{ ClientID: w.cfg.ClientID, AuthStyle: oauth2.AuthStyleInParams, TokenURL: endpoints.TokenEndpoint, - Scopes: []string{"all-apis"}, + Scopes: scopes, EndpointParams: url.Values{ "subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"}, "subject_token": {idToken.Value}, diff --git a/config/experimental/auth/oidc/tokensource_test.go b/config/experimental/auth/oidc/tokensource_test.go index 3410978c6..604ea8ca9 100644 --- a/config/experimental/auth/oidc/tokensource_test.go +++ b/config/experimental/auth/oidc/tokensource_test.go @@ -319,3 +319,94 @@ func TestDatabricksOidcTokenSource(t *testing.T) { }) } } + +func TestWIF_Scopes(t *testing.T) { + const ( + testClientID = "test-client-id" + testIDToken = "test-id-token" + testAccessToken = "test-access-token" + testTokenPath = "/oidc/v1/token" + testHost = "https://host.com" + ) + + tests := []struct { + name string + scopes []string + want string + }{ + { + name: "nil scopes uses default", + scopes: nil, + want: "all-apis", + }, + { + name: "empty scopes uses default", + scopes: []string{}, + want: "all-apis", + }, + { + name: "single scope", + scopes: []string{"dashboards"}, + want: "dashboards", + }, + { + name: "multiple scopes", + scopes: []string{"jobs", "files:read", "mlflow"}, + want: "jobs files:read mlflow", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DatabricksOIDCTokenSourceConfig{ + ClientID: testClientID, + Host: testHost, + TokenEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) { + return &u2m.OAuthAuthorizationServer{ + TokenEndpoint: testHost + testTokenPath, + }, nil + }, + Audience: "token-audience", + IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) { + return &IDToken{Value: testIDToken}, nil + }), + scopes: tt.scopes, + } + + ts := NewDatabricksOIDCTokenSource(cfg) + + // The scope assertion: verifies the token source sends the correct scope parameter. + expectedRequest := url.Values{ + "client_id": {testClientID}, + "scope": {tt.want}, + "subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"}, + "subject_token": {testIDToken}, + "grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"}, + } + + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{ + Transport: fixtures.MappingTransport{ + "POST " + testTokenPath: { + Status: http.StatusOK, + ExpectedHeaders: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + ExpectedRequest: expectedRequest, + Response: map[string]string{ + "token_type": "Bearer", + "access_token": testAccessToken, + }, + }, + }, + }) + + token, err := ts.Token(ctx) + if err != nil { + t.Fatalf("Token(ctx): got error %q, want none", err) + } + if token.AccessToken != testAccessToken { + t.Errorf("Token(ctx): got access token %q, want %q", token.AccessToken, testAccessToken) + } + }) + } +}