Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ var databricksAWSDomains []string = []string{
}

var databricksAzureDomains []string = []string{
".staging.azuredatabricks.net",
".dev.azuredatabricks.net",
".azuredatabricks.net",
".databricks.azure.cn",
".databricks.azure.us",
Expand Down
86 changes: 86 additions & 0 deletions auth/tokenprovider/cached.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package tokenprovider

import (
"context"
"fmt"
"sync"
"time"

"github.com/rs/zerolog/log"
)

// CachedTokenProvider decorates another TokenProvider with a single-entry
// token cache, so repeated GetToken calls reuse the last fetched token until
// it gets close to its expiry.
type CachedTokenProvider struct {
	// RefreshThreshold is how long before a token's expiry the token is
	// considered stale and fetched again. NewCachedTokenProvider sets it to
	// 5 minutes.
	RefreshThreshold time.Duration

	// provider is the wrapped source that fresh tokens are fetched from.
	provider TokenProvider

	// mutex guards cache.
	mutex sync.RWMutex
	// cache is the most recently fetched token; it is nil until the first
	// successful fetch and after ClearCache.
	cache *Token
}

// NewCachedTokenProvider returns a CachedTokenProvider that fetches tokens from
// provider and starts with an empty cache. The refresh threshold is initialised
// to 5 minutes; callers may change RefreshThreshold on the returned value.
func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
	cached := new(CachedTokenProvider)
	cached.provider = provider
	cached.RefreshThreshold = 5 * time.Minute
	return cached
}

// GetToken returns the cached token when one exists and is not yet due for
// refresh; otherwise it fetches a new token from the wrapped provider, stores
// it in the cache and returns it.
//
// A failed fetch is returned wrapped and leaves the cache untouched.
func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) {
	// Fast path: take a snapshot of the cache under the read lock only.
	p.mutex.RLock()
	snapshot := p.cache
	p.mutex.RUnlock()

	if snapshot != nil && !p.shouldRefresh(snapshot) {
		log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
		return snapshot, nil
	}

	// Slow path: the write lock is held for the whole fetch, so concurrent
	// callers wait here instead of each calling the wrapped provider.
	p.mutex.Lock()
	defer p.mutex.Unlock()

	// Another caller may have refreshed the cache while this one was waiting
	// for the write lock; re-check before fetching.
	if current := p.cache; current != nil && !p.shouldRefresh(current) {
		return current, nil
	}

	log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
	fresh, fetchErr := p.provider.GetToken(ctx)
	if fetchErr != nil {
		return nil, fmt.Errorf("cached token provider: failed to get token: %w", fetchErr)
	}

	p.cache = fresh
	return fresh, nil
}

// shouldRefresh reports whether token must be replaced by a fresh one.
//
// A nil token always needs a refresh. A token with a zero ExpiresAt is treated
// as non-expiring and never needs one. Any other token needs a refresh once the
// current time is past ExpiresAt minus RefreshThreshold.
func (p *CachedTokenProvider) shouldRefresh(token *Token) bool {
	switch {
	case token == nil:
		return true
	case token.ExpiresAt.IsZero():
		return false
	}

	// The refresh window opens RefreshThreshold before the actual expiry.
	deadline := token.ExpiresAt.Add(-p.RefreshThreshold)
	return deadline.Before(time.Now())
}

// Name identifies this provider as "cached[<inner>]", where <inner> is the name
// reported by the wrapped provider.
func (p *CachedTokenProvider) Name() string {
	inner := p.provider.Name()
	return fmt.Sprintf("cached[%s]", inner)
}

// ClearCache drops the cached token, so the next GetToken call fetches a new
// one from the wrapped provider.
func (p *CachedTokenProvider) ClearCache() {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	p.cache = nil
}
204 changes: 204 additions & 0 deletions auth/tokenprovider/exchange.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package tokenprovider

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
)

// FederationProvider wraps another TokenProvider and, when the wrapped
// provider's token is a JWT issued by a host other than the Databricks host,
// exchanges it for a Databricks token via the OIDC token endpoint.
type FederationProvider struct {
	// baseProvider supplies the token that may need to be exchanged.
	baseProvider TokenProvider

	// databricksHost is the workspace host the exchange request is sent to.
	// It may be given with or without a scheme; "https://" is assumed when
	// the scheme is missing.
	databricksHost string

	// clientID is sent as client_id in the exchange request when non-empty
	// (service-principal-wide federation).
	clientID string

	// httpClient performs the exchange request.
	httpClient *http.Client

	// returnOriginalTokenIfAuthenticated, when true, adds
	// return_original_token_if_authenticated=true to the exchange request.
	returnOriginalTokenIfAuthenticated bool
}

// NewFederationProvider returns a FederationProvider that takes tokens from
// baseProvider and exchanges them against databricksHost when required.
//
// No client ID is configured. The provider gets its own HTTP client with a
// 30-second timeout and asks the endpoint to return the original token when it
// is already authenticated.
func NewFederationProvider(baseProvider TokenProvider, databricksHost string) *FederationProvider {
	fp := new(FederationProvider)
	fp.baseProvider = baseProvider
	fp.databricksHost = databricksHost
	fp.httpClient = &http.Client{Timeout: 30 * time.Second}
	fp.returnOriginalTokenIfAuthenticated = true
	return fp
}

// NewFederationProviderWithClientID returns a FederationProvider for
// service-principal-wide (M2M) federation: clientID is sent along with every
// exchange request made against databricksHost.
//
// As with NewFederationProvider, the provider gets its own HTTP client with a
// 30-second timeout and asks the endpoint to return the original token when it
// is already authenticated.
func NewFederationProviderWithClientID(baseProvider TokenProvider, databricksHost, clientID string) *FederationProvider {
	client := &http.Client{Timeout: 30 * time.Second}

	fp := FederationProvider{
		returnOriginalTokenIfAuthenticated: true,
		httpClient:                         client,
		clientID:                           clientID,
		databricksHost:                     databricksHost,
		baseProvider:                       baseProvider,
	}
	return &fp
}

// GetToken fetches a token from the base provider and, when that token is a
// JWT from an issuer other than the Databricks host, tries to exchange it for
// a Databricks token.
//
// The exchange is best effort: if it fails, the failure is logged and the
// original base token is returned instead. An error is returned only when the
// base provider fails or hands back no token at all.
func (p *FederationProvider) GetToken(ctx context.Context) (*Token, error) {
	baseToken, err := p.baseProvider.GetToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("federation provider: failed to get base token: %w", err)
	}
	// A provider that returns (nil, nil) would otherwise cause a nil-pointer
	// panic on baseToken.AccessToken below.
	if baseToken == nil {
		return nil, fmt.Errorf("federation provider: base provider %s returned no token", p.baseProvider.Name())
	}

	// Tokens that are not foreign-issued JWTs are passed through untouched.
	if !p.needsTokenExchange(baseToken.AccessToken) {
		return baseToken, nil
	}

	log.Debug().Msgf("federation provider: attempting token exchange for %s", p.baseProvider.Name())

	exchangedToken, err := p.tryTokenExchange(ctx, baseToken.AccessToken)
	if err != nil {
		// Deliberate fallback: the original token may still be accepted.
		log.Warn().Err(err).Msg("federation provider: token exchange failed, using original token")
		return baseToken, nil
	}

	log.Debug().Msg("federation provider: token exchange successful")
	return exchangedToken, nil
}

// needsTokenExchange reports whether tokenString should be exchanged for a
// Databricks token.
//
// The token is parsed as a JWT without verifying its signature; only the "iss"
// claim is inspected. The result is true when that issuer's host differs from
// the configured Databricks host. Anything that is not a parseable JWT, or that
// lacks a string "iss" claim, yields false.
func (p *FederationProvider) needsTokenExchange(tokenString string) bool {
	parser := new(jwt.Parser)
	parsed, _, parseErr := parser.ParseUnverified(tokenString, jwt.MapClaims{})
	if parseErr != nil {
		log.Debug().Err(parseErr).Msg("federation provider: not a JWT token, skipping exchange")
		return false
	}

	mapClaims, isMap := parsed.Claims.(jwt.MapClaims)
	if !isMap {
		return false
	}

	// Only a string-valued issuer can be compared against the host.
	if iss, isString := mapClaims["iss"].(string); isString {
		return !p.isSameHost(iss, p.databricksHost)
	}
	return false
}

// tryTokenExchange trades subjectToken (a JWT from an external issuer) for a
// Databricks access token using the OAuth 2.0 token-exchange grant.
//
// The request is a form-encoded POST to <databricksHost>/oidc/v1/token, with
// "https://" assumed when databricksHost carries no scheme. The form always
// asks for the "sql" scope; client_id and
// return_original_token_if_authenticated are added when configured on p.
//
// It returns an error when the request cannot be built or sent, when the body
// cannot be read, when the endpoint answers with a status other than 200, when
// the body is not valid JSON, or when the response contains no access_token.
// On success the returned token's ExpiresAt is set only if the response
// reports a positive expires_in.
func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) {
	// Upper bound on how much of the response is buffered. A token response is
	// a small JSON document, so a larger body is truncated rather than read in
	// full.
	const maxResponseBytes = 1 << 20

	// Build exchange URL - add scheme if not present.
	exchangeURL := p.databricksHost
	if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") {
		exchangeURL = "https://" + exchangeURL
	}
	if !strings.HasSuffix(exchangeURL, "/") {
		exchangeURL += "/"
	}
	exchangeURL += "oidc/v1/token"

	// Prepare form data for token exchange.
	data := url.Values{}
	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
	data.Set("scope", "sql")
	data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
	data.Set("subject_token", subjectToken)
	if p.returnOriginalTokenIfAuthenticated {
		data.Set("return_original_token_if_authenticated", "true")
	}
	// client_id is only sent for SP-wide federation.
	if p.clientID != "" {
		data.Set("client_id", p.clientID)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, exchangeURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "*/*")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int    `json:"expires_in"`
		Scope       string `json:"scope"`
	}
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// A 200 response without a token is useless to the caller; report it as a
	// failure so the caller does not end up with an empty credential.
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("exchange response did not contain an access_token")
	}

	token := &Token{
		AccessToken: tokenResp.AccessToken,
		TokenType:   tokenResp.TokenType,
		Scopes:      strings.Fields(tokenResp.Scope),
	}
	// Without a positive expires_in the expiry stays at its zero value.
	if tokenResp.ExpiresIn > 0 {
		token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
	}

	return token, nil
}

// isSameHost reports whether url1 and url2 refer to the same host.
//
// Either argument may omit the scheme (an issuer claim or databricksHost can
// be a bare host name); "https://" is assumed in that case. Only the host name
// is compared: ports, paths and schemes are ignored, and the comparison is
// case-insensitive because host names are. If either value cannot be parsed as
// a URL the result is false.
func (p *FederationProvider) isSameHost(url1, url2 string) bool {
	// hostname extracts the host name of raw, defaulting the scheme so that
	// bare hosts such as "host.com" or "host.com:443" parse as a host rather
	// than as a path.
	hostname := func(raw string) (string, bool) {
		if !strings.HasPrefix(raw, "http://") && !strings.HasPrefix(raw, "https://") {
			raw = "https://" + raw
		}
		parsed, err := url.Parse(raw)
		if err != nil {
			return "", false
		}
		// Hostname() rather than Host, so "host.com:443" equals "host.com".
		return parsed.Hostname(), true
	}

	host1, ok1 := hostname(url1)
	host2, ok2 := hostname(url2)
	if !ok1 || !ok2 {
		return false
	}

	return strings.EqualFold(host1, host2)
}

// Name identifies this provider as "federation[<base>]", where <base> is the
// name of the wrapped provider. When a client ID is configured the form is
// "federation[<base>,sp:<id>]", with <id> being at most the first 8 bytes of
// the client ID.
func (p *FederationProvider) Name() string {
	baseName := p.baseProvider.Name()
	if p.clientID == "" {
		return fmt.Sprintf("federation[%s]", baseName)
	}

	// Truncate the client ID for readability, but never slice past its end:
	// an ID shorter than the limit is used as-is.
	const maxIDLen = 8
	shortID := p.clientID
	if len(shortID) > maxIDLen {
		shortID = shortID[:maxIDLen]
	}
	return fmt.Sprintf("federation[%s,sp:%s]", baseName, shortID)
}
Loading
Loading