Skip to content

Commit 466fc83

Browse files
M2M and WIF custom scopes support
1 parent 38bdb02 commit 466fc83

File tree

10 files changed

+467
-2
lines changed

10 files changed

+467
-2
lines changed

config/auth_default.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt
145145
TokenEndpointProvider: cfg.getOidcEndpoints,
146146
Audience: cfg.TokenAudience,
147147
IDTokenSource: ts,
148+
Scopes: cfg.GetScopes(),
148149
}
149150
if cfg.HostType() != WorkspaceHost {
150151
oidcConfig.AccountID = cfg.AccountID

config/auth_default_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ package config
22

33
import (
44
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
58
"strings"
69
"testing"
10+
11+
"github.com/databricks/databricks-sdk-go/credentials/u2m"
712
)
813

914
func TestDefaultCredentials_Configure(t *testing.T) {
@@ -47,3 +52,100 @@ func TestDefaultCredentials_Configure(t *testing.T) {
4752
})
4853
}
4954
}
55+
56+
func TestGithubOIDC_Scopes(t *testing.T) {
57+
tests := []struct {
58+
name string
59+
scopes []string
60+
expectedScope string
61+
}{
62+
{
63+
name: "default scopes",
64+
scopes: nil,
65+
expectedScope: "all-apis",
66+
},
67+
{
68+
name: "custom scopes",
69+
scopes: []string{"clusters", "jobs"},
70+
expectedScope: "clusters jobs",
71+
},
72+
}
73+
74+
for _, tt := range tests {
75+
t.Run(tt.name, func(t *testing.T) {
76+
githubTokenCalled := false
77+
tokenExchangeCalled := false
78+
79+
// Simulates the GitHub Actions OIDC token endpoint.
80+
githubServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81+
githubTokenCalled = true
82+
w.Header().Set("Content-Type", "application/json")
83+
json.NewEncoder(w).Encode(map[string]string{"value": "github-id-token"})
84+
}))
85+
defer githubServer.Close()
86+
87+
// Simulates a Databricks workspace.
88+
// Asserts whether the right scopes are passed to the token exchange endpoint.
89+
var databricksServer *httptest.Server
90+
databricksServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91+
switch r.URL.Path {
92+
case "/oidc/.well-known/oauth-authorization-server":
93+
w.Header().Set("Content-Type", "application/json")
94+
json.NewEncoder(w).Encode(u2m.OAuthAuthorizationServer{
95+
AuthorizationEndpoint: "https://host.com/oidc/v1/authorize",
96+
TokenEndpoint: databricksServer.URL + "/oidc/v1/token",
97+
})
98+
99+
case "/oidc/v1/token":
100+
tokenExchangeCalled = true
101+
if err := r.ParseForm(); err != nil {
102+
t.Fatalf("Failed to parse form: %v", err)
103+
}
104+
// Verify scope is passed correctly to token exchange.
105+
if got := r.Form.Get("scope"); got != tt.expectedScope {
106+
t.Errorf("scope: got %q, want %q", got, tt.expectedScope)
107+
}
108+
w.Header().Set("Content-Type", "application/json")
109+
json.NewEncoder(w).Encode(map[string]interface{}{
110+
"token_type": "Bearer",
111+
"access_token": "databricks-access-token",
112+
"expires_in": 3600,
113+
})
114+
115+
default:
116+
t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path)
117+
http.Error(w, "Not found", http.StatusNotFound)
118+
}
119+
}))
120+
defer databricksServer.Close()
121+
122+
cfg := &Config{
123+
Host: databricksServer.URL,
124+
ClientID: "test-client-id",
125+
ActionsIDTokenRequestURL: githubServer.URL + "/github-token?version=1",
126+
ActionsIDTokenRequestToken: "github-request-token",
127+
TokenAudience: "databricks-test-audience",
128+
AuthType: "github-oidc",
129+
}
130+
if tt.scopes != nil {
131+
cfg.Scopes = tt.scopes
132+
}
133+
134+
req, _ := http.NewRequest("GET", databricksServer.URL+"/api/test", nil)
135+
err := cfg.Authenticate(req)
136+
if err != nil {
137+
t.Fatalf("Authenticate(): got error %v, want none", err)
138+
}
139+
140+
if got := req.Header.Get("Authorization"); got != "Bearer databricks-access-token" {
141+
t.Errorf("Authorization header: got %q, want %q", got, "Bearer databricks-access-token")
142+
}
143+
if !githubTokenCalled {
144+
t.Error("GitHub token endpoint was not called")
145+
}
146+
if !tokenExchangeCalled {
147+
t.Error("Token exchange endpoint was not called")
148+
}
149+
})
150+
}
151+
}

config/auth_m2m.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
3131
ClientSecret: cfg.ClientSecret,
3232
AuthStyle: oauth2.AuthStyleInHeader,
3333
TokenURL: endpoints.TokenEndpoint,
34-
Scopes: []string{"all-apis"},
34+
Scopes: cfg.GetScopes(),
3535
}).TokenSource(ctx)
3636

3737
visitor := refreshableVisitor(ts)

config/auth_m2m_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package config
22

33
import (
44
"net/url"
5+
"sort"
6+
"strings"
57
"testing"
68

79
"github.com/databricks/databricks-sdk-go/credentials/u2m"
@@ -92,3 +94,148 @@ func TestM2mNotSupported(t *testing.T) {
9294
})
9395
require.ErrorIs(t, err, u2m.ErrOAuthNotSupported)
9496
}
97+
98+
func TestM2M_Scopes(t *testing.T) {
99+
tests := []struct {
100+
name string
101+
host string
102+
accountID string
103+
scopes []string
104+
wellKnownEndpoint string
105+
authServer *u2m.OAuthAuthorizationServer
106+
tokenEndpoint string
107+
expectedToken string
108+
}{
109+
{
110+
name: "default scopes when not configured",
111+
host: "a",
112+
scopes: nil,
113+
wellKnownEndpoint: "GET /oidc/.well-known/oauth-authorization-server",
114+
authServer: &u2m.OAuthAuthorizationServer{
115+
AuthorizationEndpoint: "https://localhost:1234/dummy/auth",
116+
TokenEndpoint: "https://localhost:1234/dummy/token",
117+
},
118+
tokenEndpoint: "POST /dummy/token",
119+
expectedToken: "test-token",
120+
},
121+
{
122+
name: "single scope",
123+
host: "a",
124+
scopes: []string{"dashboards"},
125+
wellKnownEndpoint: "GET /oidc/.well-known/oauth-authorization-server",
126+
authServer: &u2m.OAuthAuthorizationServer{
127+
AuthorizationEndpoint: "https://localhost:1234/dummy/auth",
128+
TokenEndpoint: "https://localhost:1234/dummy/token",
129+
},
130+
tokenEndpoint: "POST /dummy/token",
131+
expectedToken: "test-token",
132+
},
133+
{
134+
name: "multiple scopes sorted",
135+
host: "a",
136+
scopes: []string{"jobs", "files", "mlflow"},
137+
wellKnownEndpoint: "GET /oidc/.well-known/oauth-authorization-server",
138+
authServer: &u2m.OAuthAuthorizationServer{
139+
AuthorizationEndpoint: "https://localhost:1234/dummy/auth",
140+
TokenEndpoint: "https://localhost:1234/dummy/token",
141+
},
142+
tokenEndpoint: "POST /dummy/token",
143+
expectedToken: "test-token",
144+
},
145+
{
146+
name: "empty scopes uses default",
147+
host: "a",
148+
scopes: []string{},
149+
wellKnownEndpoint: "GET /oidc/.well-known/oauth-authorization-server",
150+
authServer: &u2m.OAuthAuthorizationServer{
151+
AuthorizationEndpoint: "https://localhost:1234/dummy/auth",
152+
TokenEndpoint: "https://localhost:1234/dummy/token",
153+
},
154+
tokenEndpoint: "POST /dummy/token",
155+
expectedToken: "test-token",
156+
},
157+
{
158+
name: "workspace host",
159+
host: "https://my-workspace.cloud.databricks.com",
160+
scopes: []string{"mlflow:read"},
161+
wellKnownEndpoint: "GET /oidc/.well-known/oauth-authorization-server",
162+
authServer: &u2m.OAuthAuthorizationServer{
163+
AuthorizationEndpoint: "https://my-workspace.cloud.databricks.com/oidc/v1/authorize",
164+
TokenEndpoint: "https://my-workspace.cloud.databricks.com/oidc/v1/token",
165+
},
166+
tokenEndpoint: "POST /oidc/v1/token",
167+
expectedToken: "workspace-token",
168+
},
169+
{
170+
name: "account host",
171+
host: "accounts.cloud.databricks.com",
172+
accountID: "my-account",
173+
scopes: []string{"iam", "jobs", "files"},
174+
tokenEndpoint: "POST /oidc/accounts/my-account/v1/token",
175+
expectedToken: "account-token",
176+
},
177+
}
178+
179+
for _, tt := range tests {
180+
t.Run(tt.name, func(t *testing.T) {
181+
// Scopes are expected as a space-separated string in requests.
182+
// We sort scopes in place during config resolution.
183+
expectedScope := "all-apis"
184+
if len(tt.scopes) > 0 {
185+
sortedScopes := make([]string, len(tt.scopes))
186+
copy(sortedScopes, tt.scopes)
187+
sort.Strings(sortedScopes)
188+
expectedScope = strings.Join(sortedScopes, " ")
189+
}
190+
191+
transport := fixtures.MappingTransport{}
192+
193+
// Add well-known endpoint if defined.
194+
if tt.wellKnownEndpoint != "" {
195+
transport[tt.wellKnownEndpoint] = fixtures.HTTPFixture{
196+
Response: tt.authServer,
197+
}
198+
}
199+
200+
// Add token endpoint.
201+
transport[tt.tokenEndpoint] = fixtures.HTTPFixture{
202+
ExpectedHeaders: map[string]string{
203+
"Authorization": "Basic Yjpj",
204+
"Content-Type": "application/x-www-form-urlencoded",
205+
},
206+
ExpectedRequest: url.Values{
207+
"grant_type": {"client_credentials"},
208+
"scope": {expectedScope},
209+
},
210+
Response: oauth2.Token{
211+
TokenType: "Bearer",
212+
AccessToken: tt.expectedToken,
213+
},
214+
}
215+
216+
cfg := &Config{
217+
Host: tt.host,
218+
ClientID: "b",
219+
ClientSecret: "c",
220+
AuthType: "oauth-m2m",
221+
HTTPTransport: transport,
222+
}
223+
224+
if tt.accountID != "" {
225+
cfg.AccountID = tt.accountID
226+
}
227+
228+
if tt.scopes != nil {
229+
cfg.Scopes = tt.scopes
230+
}
231+
232+
assertHeaders(
233+
t,
234+
cfg,
235+
map[string]string{
236+
"Authorization": "Bearer " + tt.expectedToken,
237+
},
238+
)
239+
})
240+
}
241+
}

config/config.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"net/url"
99
"reflect"
10+
"sort"
1011
"strings"
1112
"sync"
1213
"time"
@@ -63,6 +64,8 @@ const (
6364
InvalidConfig ConfigType = "INVALID_CONFIG"
6465
)
6566

67+
var DefaultScopes = []string{"all-apis"}
68+
6669
// Config represents configuration for Databricks Connectivity
6770
type Config struct {
6871
// Credentials holds an instance of Credentials Strategy to authenticate with Databricks REST APIs.
@@ -135,6 +138,12 @@ type Config struct {
135138
ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth" auth_types:"oauth-m2m"`
136139
ClientSecret string `name:"client_secret" env:"DATABRICKS_CLIENT_SECRET" auth:"oauth,sensitive" auth_types:"oauth-m2m"`
137140

141+
// Scopes is a list of OAuth scopes to request when authenticating.
142+
// If not specified, defaults to ["all-apis"] for backwards compatibility.
143+
// Note: Setting scopes via environment variables is not supported.
144+
// Note: The slice is sorted in-place during config resolution.
145+
Scopes []string `name:"scopes" auth:"-"`
146+
138147
// Path to the Databricks CLI (version >= 0.100.0).
139148
DatabricksCliPath string `name:"databricks_cli_path" env:"DATABRICKS_CLI_PATH" auth_types:"databricks-cli"`
140149

@@ -445,6 +454,11 @@ func (c *Config) EnsureResolved() error {
445454
},
446455
}
447456
}
457+
// Sort scopes in-place for better de-duplication in the refresh token cache,
458+
// once scopes are supported in its cache key.
459+
if len(c.Scopes) > 0 {
460+
sort.Strings(c.Scopes)
461+
}
448462
c.resolved = true
449463
return nil
450464
}
@@ -460,6 +474,13 @@ func (c *Config) CanonicalHostName() string {
460474
return c.Host
461475
}
462476

477+
func (c *Config) GetScopes() []string {
478+
if len(c.Scopes) == 0 {
479+
return DefaultScopes
480+
}
481+
return c.Scopes
482+
}
483+
463484
func (c *Config) wrapDebug(err error) error {
464485
debug := ConfigAttributes.DebugString(c)
465486
if debug == "" {

config/config_attribute.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"os"
66
"reflect"
77
"strconv"
8+
"strings"
89
)
910

1011
type Source struct {
@@ -69,6 +70,12 @@ func (a *ConfigAttribute) SetS(cfg *Config, v string) error {
6970
return err
7071
}
7172
return a.Set(cfg, vv)
73+
case reflect.Slice:
74+
parts := strings.Split(v, ",")
75+
for i, p := range parts {
76+
parts[i] = strings.TrimSpace(p)
77+
}
78+
return a.Set(cfg, parts)
7279
default:
7380
return fmt.Errorf("cannot set %s of unknown type %s",
7481
a.Name, reflectKind(a.Kind))
@@ -85,6 +92,8 @@ func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error {
8592
field.SetBool(i.(bool))
8693
case reflect.Int:
8794
field.SetInt(int64(i.(int)))
95+
case reflect.Slice:
96+
field.Set(reflect.ValueOf(i.([]string)))
8897
default:
8998
// must extensively test with providerFixture to avoid this one
9099
return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind))

0 commit comments

Comments
 (0)