Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 1 addition & 27 deletions config/auth_azure_github_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
)

Expand All @@ -28,7 +27,7 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
return nil, nil
}

idToken, err := requestIDToken(ctx, cfg)
idToken, err := cfg.getAllOIDCSuppliers().GetOIDCToken(ctx, "api://AzureADTokenExchange")
if err != nil {
return nil, err
}
Expand All @@ -47,31 +46,6 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
return credentials.NewOAuthCredentialsProvider(refreshableVisitor(ts), ts.Token), nil
}

// requestIDToken requests an ID token from the Github Action.
func requestIDToken(ctx context.Context, cfg *Config) (string, error) {
if cfg.ActionsIDTokenRequestURL == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestURL, likely not calling from a Github action")
return "", nil
}
if cfg.ActionsIDTokenRequestToken == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestToken, likely not calling from a Github action")
return "", nil
}

resp := struct { // anonymous struct to parse the response
Value string `json:"value"`
}{}
err := cfg.refreshClient.Do(ctx, "GET", fmt.Sprintf("%s&audience=api://AzureADTokenExchange", cfg.ActionsIDTokenRequestURL),
httpclient.WithRequestHeader("Authorization", fmt.Sprintf("Bearer %s", cfg.ActionsIDTokenRequestToken)),
httpclient.WithResponseUnmarshal(&resp),
)
if err != nil {
return "", fmt.Errorf("failed to request ID token from %s: %w", cfg.ActionsIDTokenRequestURL, err)
}

return resp.Value, nil
}

// azureOIDCTokenSource implements [oauth2.TokenSource] to obtain Azure auth
// tokens from an ID token.
type azureOIDCTokenSource struct {
Expand Down
65 changes: 65 additions & 0 deletions config/auth_databricks_oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package config

import (
"context"
"net/url"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

const jwtBearerGrantTypeURN = "urn:ietf:params:oauth:grant-type:jwt-bearer"

type DatabricksOIDCCredentials struct{}

// Configure implements CredentialsStrategy.
func (d DatabricksOIDCCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.Host == "" || cfg.ClientID == "" {
return nil, nil
}

// Get the OIDC token from the environment.
// TODO: trim the first 8 characters (https://) from the host
audience := cfg.CanonicalHostName()
if cfg.IsAccountClient() {
audience = cfg.AccountID
}
idToken, err := cfg.getAllOIDCSuppliers().GetOIDCToken(ctx, audience)
if err != nil {
return nil, err
}
if idToken == "" {
logger.Debugf(ctx, "No OIDC token found")
return nil, nil
}

endpoints, err := oidcEndpoints(ctx, cfg)
if err != nil {
return nil, err
}

tsConfig := clientcredentials.Config{
ClientID: cfg.ClientID,
ClientSecret: "",
AuthStyle: oauth2.AuthStyleInParams,
TokenURL: endpoints.TokenEndpoint,
Scopes: []string{"all-apis"},
EndpointParams: url.Values{
"grant_type": {jwtBearerGrantTypeURN},
"assertion": {idToken},
},
}
ts := tsConfig.TokenSource(httpclient.WithDebug(ctx, true))
visitor := refreshableVisitor(ts)
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}

// Name implements CredentialsStrategy.
func (d DatabricksOIDCCredentials) Name() string {
return "inhouse-oidc"
}

var _ CredentialsStrategy = DatabricksOIDCCredentials{}
1 change: 1 addition & 0 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ var authProviders = []CredentialsStrategy{
PatCredentials{},
BasicCredentials{},
M2mCredentials{},
DatabricksOIDCCredentials{},
DatabricksCliCredentials{},
MetadataServiceCredentials{},

Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (c *Config) IsAccountClient() bool {
return true
}
}
return false
return strings.HasPrefix(c.Host, "https://accounts-")
}

func (c *Config) EnsureResolved() error {
Expand Down
99 changes: 99 additions & 0 deletions config/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package config

import (
"context"
"fmt"
"os"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

type oidcTokenSupplier interface {
Name() string

// GetOIDCToken returns an OIDC token for the given audience.
GetOIDCToken(ctx context.Context, audience string) (string, error)
}

type githubOIDCTokenSupplier struct {
idTokenRequestURL string
idTokenRequestToken string
client *httpclient.ApiClient
}

func githubOIDCTokenSupplierFromConfig(cfg *Config) githubOIDCTokenSupplier {
return githubOIDCTokenSupplier{
idTokenRequestURL: cfg.ActionsIDTokenRequestURL,
idTokenRequestToken: cfg.ActionsIDTokenRequestToken,
client: cfg.refreshClient,
}
}

func (g githubOIDCTokenSupplier) Name() string {
return "github"
}

// requestIDToken requests an ID token from the Github Action.
func (g githubOIDCTokenSupplier) GetOIDCToken(ctx context.Context, audience string) (string, error) {
if g.idTokenRequestURL == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestURL, likely not calling from a Github action")
return "", nil
}
if g.idTokenRequestToken == "" {
logger.Debugf(ctx, "Missing cfg.ActionsIDTokenRequestToken, likely not calling from a Github action")
return "", nil
}
url := g.idTokenRequestURL
if audience != "" {
url = fmt.Sprintf("%s&audience=%s", url, audience)
}
resp := struct { // anonymous struct to parse the response
Value string `json:"value"`
}{}
err := g.client.Do(ctx, "GET", url,
httpclient.WithRequestHeader("Authorization", fmt.Sprintf("Bearer %s", g.idTokenRequestToken)),
httpclient.WithResponseUnmarshal(&resp),
)
if err != nil {
return "", fmt.Errorf("failed to request ID token from %s: %w", g.idTokenRequestURL, err)
}

return resp.Value, nil
}

var _ oidcTokenSupplier = githubOIDCTokenSupplier{}

type azureDevOpsOIDCTokenSupplier struct{}

func (a azureDevOpsOIDCTokenSupplier) Name() string {
return "azure-devops"
}

func (a azureDevOpsOIDCTokenSupplier) GetOIDCToken(ctx context.Context, audience string) (string, error) {
return os.Getenv("idToken"), nil
}

type oidcTokenSuppliers []oidcTokenSupplier

func (c *Config) getAllOIDCSuppliers() oidcTokenSuppliers {
return []oidcTokenSupplier{
githubOIDCTokenSupplierFromConfig(c),
azureDevOpsOIDCTokenSupplier{},
}
}

func (o oidcTokenSuppliers) GetOIDCToken(ctx context.Context, audience string) (string, error) {
for _, s := range o {
token, err := s.GetOIDCToken(ctx, audience)
if err != nil {
return "", err
}
if token != "" {
logger.Debugf(ctx, "OIDC token found from %s", s.Name())
return token, nil
}
logger.Debugf(ctx, "No OIDC token found from %s", s.Name())
}
return "", nil
}
29 changes: 25 additions & 4 deletions httpclient/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -327,16 +328,36 @@ func (c *ApiClient) recordRequestLog(
logger.Debugf(ctx, "%s", message)
}

type debugKeyType int

const debugKey debugKeyType = 1

func WithDebug(ctx context.Context, debug bool) context.Context {
return context.WithValue(ctx, debugKey, debug)
}

func IsDebug(ctx context.Context) bool {
debug, ok := ctx.Value(debugKey).(bool)
return ok && debug
}

func getDebugBody(ctx context.Context, body io.Reader) (io.Reader, []byte) {
if IsDebug(ctx) {
debugBytes, _ := io.ReadAll(body)
return strings.NewReader(string(debugBytes)), debugBytes
}
return body, []byte("<http.RoundTripper>")
}

// RoundTrip implements http.RoundTripper to integrate with golang.org/x/oauth2
func (c *ApiClient) RoundTrip(request *http.Request) (*http.Response, error) {
ctx := request.Context()
requestURL := request.URL.String()
body, debugBytes := getDebugBody(ctx, request.Body)
resp, err := retries.Poll(ctx, c.config.RetryTimeout,
c.attempt(ctx, request.Method, requestURL, common.RequestBody{
Reader: request.Body,
// DO NOT DECODE BODY, because it may contain sensitive payload,
// like Azure Service Principal in a multipart/form-data body.
DebugBytes: []byte("<http.RoundTripper>"),
Reader: body,
DebugBytes: debugBytes,
}, func(r *http.Request) error {
r.Header = request.Header
return nil
Expand Down