Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt
if cfg.HostType() != WorkspaceHost {
oidcConfig.AccountID = cfg.AccountID
}
oidcConfig.SetScopes(cfg.GetScopes())
tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig)
return NewTokenSourceStrategy(name, tokenSource)
}
Expand Down
103 changes: 103 additions & 0 deletions config/auth_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package config

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/databricks/databricks-sdk-go/credentials/u2m"
)

func TestDefaultCredentials_Configure(t *testing.T) {
Expand Down Expand Up @@ -47,3 +52,101 @@ func TestDefaultCredentials_Configure(t *testing.T) {
})
}
}

func TestGithubOIDC_Scopes(t *testing.T) {
const oidcTokenPath = "/oidc/v1/token"

tests := []struct {
name string
scopes []string
want string
}{
{
name: "nil scopes uses default",
scopes: nil,
want: "all-apis",
},
{
name: "empty scopes uses default",
scopes: []string{},
want: "all-apis",
},
{
name: "single scope",
scopes: []string{"clusters"},
want: "clusters",
},
{
name: "multiple scopes are sorted",
scopes: []string{"jobs", "clusters", "files:read"},
want: "clusters files:read jobs",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock GitHub server for OIDC token requests.
githubServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"value": "github-id-token"})
}))
defer githubServer.Close()

// Mock Databricks server to verify the SDK passes the correct scopes.
var databricksServer *httptest.Server
databricksServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oidc/.well-known/oauth-authorization-server":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(u2m.OAuthAuthorizationServer{
AuthorizationEndpoint: "https://host.com/oidc/v1/authorize",
TokenEndpoint: databricksServer.URL + oidcTokenPath,
})

case oidcTokenPath:
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// The scope assertion: verifies the SDK sends the correct scope parameter.
if got := r.Form.Get("scope"); got != tt.want {
t.Errorf("scope: got %q, want %q", got, tt.want)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"token_type": "Bearer",
"access_token": "databricks-access-token",
"expires_in": 3600,
})

default:
t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path)
http.Error(w, "Not found", http.StatusNotFound)
}
}))
defer databricksServer.Close()

cfg := &Config{
Host: databricksServer.URL,
ClientID: "test-client-id",
ActionsIDTokenRequestURL: githubServer.URL + "/github-token?version=1",
ActionsIDTokenRequestToken: "github-request-token",
TokenAudience: "databricks-test-audience",
AuthType: "github-oidc",
Scopes: tt.scopes,
}

req, err := http.NewRequest("GET", databricksServer.URL+"/api/test", nil)
if err != nil {
t.Fatalf("http.NewRequest(): unexpected error: %v", err)
}
err = cfg.Authenticate(req)
if err != nil {
t.Fatalf("Authenticate(): unexpected error: %v", err)
}
wantAuthHeader := "Bearer databricks-access-token"
if got := req.Header.Get("Authorization"); got != wantAuthHeader {
t.Errorf("Authorization header: got %q, want %q", got, wantAuthHeader)
}
})
}
}
22 changes: 21 additions & 1 deletion config/experimental/auth/oidc/tokensource.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ type DatabricksOIDCTokenSourceConfig struct {

// IDTokenSource returns the IDToken to be used for the token exchange.
IDTokenSource IDTokenSource

// scopes is the list of OAuth scopes to request.
scopes []string
}

// GetScopes returns the OAuth scopes to request. If no scopes have been set,
// it returns the default scope "all-apis".
func (c *DatabricksOIDCTokenSourceConfig) GetScopes() []string {
if len(c.scopes) == 0 {
return []string{"all-apis"}
}
return c.scopes
}

// SetScopes sets the OAuth scopes to request.
func (c *DatabricksOIDCTokenSourceConfig) SetScopes(scopes []string) {
c.scopes = scopes
}

// NewDatabricksOIDCTokenSource returns a new Databricks OIDC TokenSource.
Expand Down Expand Up @@ -77,11 +94,14 @@ func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, e
return nil, err
}

// This nil check is to ensure backwards compatibility for users implementing their own
// OIDC token source.
scopes := w.cfg.GetScopes()
c := &clientcredentials.Config{
ClientID: w.cfg.ClientID,
AuthStyle: oauth2.AuthStyleInParams,
TokenURL: endpoints.TokenEndpoint,
Scopes: []string{"all-apis"},
Scopes: scopes,
EndpointParams: url.Values{
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
"subject_token": {idToken.Value},
Expand Down
91 changes: 91 additions & 0 deletions config/experimental/auth/oidc/tokensource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,94 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
})
}
}

func TestWIF_Scopes(t *testing.T) {
const (
testClientID = "test-client-id"
testIDToken = "test-id-token"
testAccessToken = "test-access-token"
testTokenPath = "/oidc/v1/token"
testHost = "https://host.com"
)

tests := []struct {
name string
scopes []string
want string
}{
{
name: "nil scopes uses default",
scopes: nil,
want: "all-apis",
},
{
name: "empty scopes uses default",
scopes: []string{},
want: "all-apis",
},
{
name: "single scope",
scopes: []string{"dashboards"},
want: "dashboards",
},
{
name: "multiple scopes",
scopes: []string{"jobs", "files:read", "mlflow"},
want: "jobs files:read mlflow",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := DatabricksOIDCTokenSourceConfig{
ClientID: testClientID,
Host: testHost,
TokenEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) {
return &u2m.OAuthAuthorizationServer{
TokenEndpoint: testHost + testTokenPath,
}, nil
},
Audience: "token-audience",
IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
return &IDToken{Value: testIDToken}, nil
}),
scopes: tt.scopes,
}

ts := NewDatabricksOIDCTokenSource(cfg)

// The scope assertion: verifies the token source sends the correct scope parameter.
expectedRequest := url.Values{
"client_id": {testClientID},
"scope": {tt.want},
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
"subject_token": {testIDToken},
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{
Transport: fixtures.MappingTransport{
"POST " + testTokenPath: {
Status: http.StatusOK,
ExpectedHeaders: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
ExpectedRequest: expectedRequest,
Response: map[string]string{
"token_type": "Bearer",
"access_token": testAccessToken,
},
},
},
})

token, err := ts.Token(ctx)
if err != nil {
t.Fatalf("Token(ctx): got error %q, want none", err)
}
if token.AccessToken != testAccessToken {
t.Errorf("Token(ctx): got access token %q, want %q", token.AccessToken, testAccessToken)
}
})
}
}
Loading