Skip to content

Commit a4831d7

Browse files
custom scopes support in wif
1 parent fbf6a6a commit a4831d7

File tree

4 files changed

+239
-1
lines changed

4 files changed

+239
-1
lines changed

config/auth_default.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt
145145
TokenEndpointProvider: cfg.getOidcEndpoints,
146146
Audience: cfg.TokenAudience,
147147
IDTokenSource: ts,
148+
Scopes: cfg.GetScopes(),
148149
}
149150
if cfg.HostType() != WorkspaceHost {
150151
oidcConfig.AccountID = cfg.AccountID

config/auth_default_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ package config
22

33
import (
44
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
58
"strings"
69
"testing"
10+
11+
"github.com/databricks/databricks-sdk-go/credentials/u2m"
712
)
813

914
func TestDefaultCredentials_Configure(t *testing.T) {
@@ -47,3 +52,100 @@ func TestDefaultCredentials_Configure(t *testing.T) {
4752
})
4853
}
4954
}
55+
56+
func TestGithubOIDC_Scopes(t *testing.T) {
57+
tests := []struct {
58+
name string
59+
scopes []string
60+
expectedScope string
61+
}{
62+
{
63+
name: "default scopes",
64+
scopes: nil,
65+
expectedScope: "all-apis",
66+
},
67+
{
68+
name: "custom scopes",
69+
scopes: []string{"clusters", "jobs"},
70+
expectedScope: "clusters jobs",
71+
},
72+
}
73+
74+
for _, tt := range tests {
75+
t.Run(tt.name, func(t *testing.T) {
76+
githubTokenCalled := false
77+
tokenExchangeCalled := false
78+
79+
// Simulates the GitHub Actions OIDC token endpoint.
80+
githubServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81+
githubTokenCalled = true
82+
w.Header().Set("Content-Type", "application/json")
83+
json.NewEncoder(w).Encode(map[string]string{"value": "github-id-token"})
84+
}))
85+
defer githubServer.Close()
86+
87+
// Simulates a Databricks workspace.
88+
// Asserts whether the right scopes are passed to the token exchange endpoint.
89+
var databricksServer *httptest.Server
90+
databricksServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91+
switch r.URL.Path {
92+
case "/oidc/.well-known/oauth-authorization-server":
93+
w.Header().Set("Content-Type", "application/json")
94+
json.NewEncoder(w).Encode(u2m.OAuthAuthorizationServer{
95+
AuthorizationEndpoint: "https://host.com/oidc/v1/authorize",
96+
TokenEndpoint: databricksServer.URL + "/oidc/v1/token",
97+
})
98+
99+
case "/oidc/v1/token":
100+
tokenExchangeCalled = true
101+
if err := r.ParseForm(); err != nil {
102+
t.Fatalf("Failed to parse form: %v", err)
103+
}
104+
// Verify scope is passed correctly to token exchange.
105+
if got := r.Form.Get("scope"); got != tt.expectedScope {
106+
t.Errorf("scope: got %q, want %q", got, tt.expectedScope)
107+
}
108+
w.Header().Set("Content-Type", "application/json")
109+
json.NewEncoder(w).Encode(map[string]interface{}{
110+
"token_type": "Bearer",
111+
"access_token": "databricks-access-token",
112+
"expires_in": 3600,
113+
})
114+
115+
default:
116+
t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path)
117+
http.Error(w, "Not found", http.StatusNotFound)
118+
}
119+
}))
120+
defer databricksServer.Close()
121+
122+
cfg := &Config{
123+
Host: databricksServer.URL,
124+
ClientID: "test-client-id",
125+
ActionsIDTokenRequestURL: githubServer.URL + "/github-token?version=1",
126+
ActionsIDTokenRequestToken: "github-request-token",
127+
TokenAudience: "databricks-test-audience",
128+
AuthType: "github-oidc",
129+
}
130+
if tt.scopes != nil {
131+
cfg.Scopes = tt.scopes
132+
}
133+
134+
req, _ := http.NewRequest("GET", databricksServer.URL+"/api/test", nil)
135+
err := cfg.Authenticate(req)
136+
if err != nil {
137+
t.Fatalf("Authenticate(): got error %v, want none", err)
138+
}
139+
140+
if got := req.Header.Get("Authorization"); got != "Bearer databricks-access-token" {
141+
t.Errorf("Authorization header: got %q, want %q", got, "Bearer databricks-access-token")
142+
}
143+
if !githubTokenCalled {
144+
t.Error("GitHub token endpoint was not called")
145+
}
146+
if !tokenExchangeCalled {
147+
t.Error("Token exchange endpoint was not called")
148+
}
149+
})
150+
}
151+
}

config/experimental/auth/oidc/tokensource.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ type DatabricksOIDCTokenSourceConfig struct {
3939

4040
// IDTokenSource returns the IDToken to be used for the token exchange.
4141
IDTokenSource IDTokenSource
42+
43+
// Scopes is the list of OAuth scopes to request.
44+
Scopes []string
4245
}
4346

4447
// NewDatabricksOIDCTokenSource returns a new Databricks OIDC TokenSource.
@@ -81,7 +84,7 @@ func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, e
8184
ClientID: w.cfg.ClientID,
8285
AuthStyle: oauth2.AuthStyleInParams,
8386
TokenURL: endpoints.TokenEndpoint,
84-
Scopes: []string{"all-apis"},
87+
Scopes: w.cfg.Scopes,
8588
EndpointParams: url.Values{
8689
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
8790
"subject_token": {idToken.Value},

config/experimental/auth/oidc/tokensource_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
275275
ClientID: tc.clientID,
276276
AccountID: tc.accountID,
277277
Host: tc.host,
278+
Scopes: []string{"all-apis"},
278279
TokenEndpointProvider: tc.oidcEndpointProvider,
279280
Audience: tc.tokenAudience,
280281
IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
@@ -319,3 +320,134 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
319320
})
320321
}
321322
}
323+
324+
func TestWIF_Scopes(t *testing.T) {
325+
tests := []struct {
326+
name string
327+
clientID string
328+
accountID string
329+
host string
330+
audience string
331+
scopes []string
332+
tokenEndpoint string
333+
expectedClientID string
334+
expectedScope string
335+
expectedAccessToken string
336+
}{
337+
{
338+
name: "single scope",
339+
clientID: "client-id",
340+
host: "http://host.com",
341+
audience: "token-audience",
342+
scopes: []string{"dashboards"},
343+
tokenEndpoint: "https://host.com/oidc/v1/token",
344+
expectedClientID: "client-id",
345+
expectedScope: "dashboards",
346+
expectedAccessToken: "test-token",
347+
},
348+
{
349+
name: "multiple scopes sorted",
350+
clientID: "client-id",
351+
host: "http://host.com",
352+
audience: "token-audience",
353+
scopes: []string{"files", "jobs", "mlflow"},
354+
tokenEndpoint: "https://host.com/oidc/v1/token",
355+
expectedClientID: "client-id",
356+
expectedScope: "files jobs mlflow",
357+
expectedAccessToken: "test-token",
358+
},
359+
{
360+
name: "workspace-level WIF",
361+
clientID: "client-id",
362+
host: "https://my-workspace.cloud.databricks.com",
363+
audience: "workspace-audience",
364+
scopes: []string{"genie"},
365+
tokenEndpoint: "https://my-workspace.cloud.databricks.com/oidc/v1/token",
366+
expectedClientID: "client-id",
367+
expectedScope: "genie",
368+
expectedAccessToken: "workspace-token",
369+
},
370+
{
371+
name: "account-level WIF",
372+
clientID: "client-id",
373+
accountID: "my-account",
374+
host: "https://accounts.cloud.databricks.com",
375+
audience: "account-audience",
376+
scopes: []string{"files", "iam"},
377+
tokenEndpoint: "https://accounts.cloud.databricks.com/oidc/accounts/my-account/v1/token",
378+
expectedClientID: "client-id",
379+
expectedScope: "files iam",
380+
expectedAccessToken: "account-token",
381+
},
382+
{
383+
name: "account-wide token federation (no ClientID)",
384+
clientID: "",
385+
host: "http://host.com",
386+
audience: "token-audience",
387+
scopes: []string{"workspaces"},
388+
tokenEndpoint: "https://host.com/oidc/v1/token",
389+
expectedClientID: "",
390+
expectedScope: "workspaces",
391+
expectedAccessToken: "account-wide-token",
392+
},
393+
}
394+
395+
for _, tt := range tests {
396+
t.Run(tt.name, func(t *testing.T) {
397+
cfg := DatabricksOIDCTokenSourceConfig{
398+
ClientID: tt.clientID,
399+
AccountID: tt.accountID,
400+
Host: tt.host,
401+
TokenEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) {
402+
return &u2m.OAuthAuthorizationServer{
403+
TokenEndpoint: tt.tokenEndpoint,
404+
}, nil
405+
},
406+
Audience: tt.audience,
407+
IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
408+
return &IDToken{Value: "id-token"}, nil
409+
}),
410+
Scopes: tt.scopes,
411+
}
412+
413+
ts := NewDatabricksOIDCTokenSource(cfg)
414+
415+
expectedRequest := url.Values{
416+
"scope": {tt.expectedScope},
417+
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
418+
"subject_token": {"id-token"},
419+
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
420+
}
421+
if tt.expectedClientID != "" {
422+
expectedRequest["client_id"] = []string{tt.expectedClientID}
423+
}
424+
425+
endpointURL, _ := url.Parse(tt.tokenEndpoint)
426+
endpointPath := "POST " + endpointURL.Path
427+
428+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{
429+
Transport: fixtures.MappingTransport{
430+
endpointPath: {
431+
Status: http.StatusOK,
432+
ExpectedHeaders: map[string]string{
433+
"Content-Type": "application/x-www-form-urlencoded",
434+
},
435+
ExpectedRequest: expectedRequest,
436+
Response: map[string]string{
437+
"token_type": "Bearer",
438+
"access_token": tt.expectedAccessToken,
439+
},
440+
},
441+
},
442+
})
443+
444+
token, err := ts.Token(ctx)
445+
if err != nil {
446+
t.Fatalf("Token(ctx): got error %q, want none", err)
447+
}
448+
if token.AccessToken != tt.expectedAccessToken {
449+
t.Errorf("Token(ctx): got access token %q, want %q", token.AccessToken, tt.expectedAccessToken)
450+
}
451+
})
452+
}
453+
}

0 commit comments

Comments
 (0)