Skip to content

Commit 1fb8c1a

Browse files
madhav-db and claude committed
Add token caching and federation support
Adds automatic token exchange (federation) and caching capabilities:
- CachedTokenProvider: Automatic token refresh with a 5-minute buffer
- FederationProvider: Auto-detects and exchanges external JWT tokens
- Supports both user federation and SP-wide (M2M) federation
- Graceful fallback if token exchange is unavailable
- Connector functions: WithFederatedTokenProvider, WithFederatedTokenProviderAndClientID
- Azure domain list updates for staging/dev environments

Token exchange follows the RFC 8693 standard.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent e857292 commit 1fb8c1a

File tree

5 files changed

+663
-0
lines changed

5 files changed

+663
-0
lines changed

auth/oauth/oauth.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ var databricksAWSDomains []string = []string{
8585
}
8686

8787
var databricksAzureDomains []string = []string{
88+
".staging.azuredatabricks.net",
89+
".dev.azuredatabricks.net",
8890
".azuredatabricks.net",
8991
".databricks.azure.cn",
9092
".databricks.azure.us",

auth/tokenprovider/cached.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"github.com/rs/zerolog/log"
10+
)
11+
12+
// CachedTokenProvider wraps another provider and caches tokens
13+
type CachedTokenProvider struct {
14+
provider TokenProvider
15+
cache *Token
16+
mutex sync.RWMutex
17+
// RefreshThreshold determines when to refresh (default 5 minutes before expiry)
18+
RefreshThreshold time.Duration
19+
}
20+
21+
// NewCachedTokenProvider creates a caching wrapper around any token provider
22+
func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
23+
return &CachedTokenProvider{
24+
provider: provider,
25+
RefreshThreshold: 5 * time.Minute,
26+
}
27+
}
28+
29+
// GetToken retrieves a token, using cache if available and valid
30+
func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) {
31+
// Try to get from cache first
32+
p.mutex.RLock()
33+
cached := p.cache
34+
p.mutex.RUnlock()
35+
36+
if cached != nil && !p.shouldRefresh(cached) {
37+
log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
38+
return cached, nil
39+
}
40+
41+
// Need to refresh
42+
p.mutex.Lock()
43+
defer p.mutex.Unlock()
44+
45+
// Double-check after acquiring write lock
46+
if p.cache != nil && !p.shouldRefresh(p.cache) {
47+
return p.cache, nil
48+
}
49+
50+
log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
51+
token, err := p.provider.GetToken(ctx)
52+
if err != nil {
53+
return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
54+
}
55+
56+
p.cache = token
57+
return token, nil
58+
}
59+
60+
// shouldRefresh determines if a token should be refreshed
61+
func (p *CachedTokenProvider) shouldRefresh(token *Token) bool {
62+
if token == nil {
63+
return true
64+
}
65+
66+
// If no expiry time, assume token doesn't expire
67+
if token.ExpiresAt.IsZero() {
68+
return false
69+
}
70+
71+
// Refresh if within threshold of expiry
72+
refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold)
73+
return time.Now().After(refreshAt)
74+
}
75+
76+
// Name returns the provider name
77+
func (p *CachedTokenProvider) Name() string {
78+
return fmt.Sprintf("cached[%s]", p.provider.Name())
79+
}
80+
81+
// ClearCache clears the cached token
82+
func (p *CachedTokenProvider) ClearCache() {
83+
p.mutex.Lock()
84+
p.cache = nil
85+
p.mutex.Unlock()
86+
}

auth/tokenprovider/exchange.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
package tokenprovider
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"strings"
11+
"time"
12+
13+
"github.com/golang-jwt/jwt/v5"
14+
"github.com/rs/zerolog/log"
15+
)
16+
17+
// FederationProvider wraps another token provider and automatically handles token exchange
18+
type FederationProvider struct {
19+
baseProvider TokenProvider
20+
databricksHost string
21+
clientID string // For SP-wide federation
22+
httpClient *http.Client
23+
// Settings for token exchange
24+
returnOriginalTokenIfAuthenticated bool
25+
}
26+
27+
// NewFederationProvider creates a federation provider that wraps another provider
28+
// It automatically detects when token exchange is needed and falls back gracefully
29+
func NewFederationProvider(baseProvider TokenProvider, databricksHost string) *FederationProvider {
30+
return &FederationProvider{
31+
baseProvider: baseProvider,
32+
databricksHost: databricksHost,
33+
httpClient: &http.Client{Timeout: 30 * time.Second},
34+
returnOriginalTokenIfAuthenticated: true,
35+
}
36+
}
37+
38+
// NewFederationProviderWithClientID creates a provider for SP-wide federation (M2M)
39+
func NewFederationProviderWithClientID(baseProvider TokenProvider, databricksHost, clientID string) *FederationProvider {
40+
return &FederationProvider{
41+
baseProvider: baseProvider,
42+
databricksHost: databricksHost,
43+
clientID: clientID,
44+
httpClient: &http.Client{Timeout: 30 * time.Second},
45+
returnOriginalTokenIfAuthenticated: true,
46+
}
47+
}
48+
49+
// GetToken gets token from base provider and exchanges if needed
50+
func (p *FederationProvider) GetToken(ctx context.Context) (*Token, error) {
51+
// Get token from base provider
52+
baseToken, err := p.baseProvider.GetToken(ctx)
53+
if err != nil {
54+
return nil, fmt.Errorf("federation provider: failed to get base token: %w", err)
55+
}
56+
57+
// Check if token is a JWT and needs exchange
58+
if p.needsTokenExchange(baseToken.AccessToken) {
59+
log.Debug().Msgf("federation provider: attempting token exchange for %s", p.baseProvider.Name())
60+
61+
// Try token exchange
62+
exchangedToken, err := p.tryTokenExchange(ctx, baseToken.AccessToken)
63+
if err != nil {
64+
log.Warn().Err(err).Msg("federation provider: token exchange failed, using original token")
65+
return baseToken, nil // Fall back to original token
66+
}
67+
68+
log.Debug().Msg("federation provider: token exchange successful")
69+
return exchangedToken, nil
70+
}
71+
72+
// Use original token
73+
return baseToken, nil
74+
}
75+
76+
// needsTokenExchange determines if a token needs exchange by checking if it's from a different issuer
77+
func (p *FederationProvider) needsTokenExchange(tokenString string) bool {
78+
// Try to parse as JWT
79+
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
80+
if err != nil {
81+
log.Debug().Err(err).Msg("federation provider: not a JWT token, skipping exchange")
82+
return false
83+
}
84+
85+
claims, ok := token.Claims.(jwt.MapClaims)
86+
if !ok {
87+
return false
88+
}
89+
90+
issuer, ok := claims["iss"].(string)
91+
if !ok {
92+
return false
93+
}
94+
95+
// Check if issuer is different from Databricks host
96+
return !p.isSameHost(issuer, p.databricksHost)
97+
}
98+
99+
// tryTokenExchange attempts to exchange the token with Databricks
100+
func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) {
101+
// Build exchange URL - add scheme if not present
102+
exchangeURL := p.databricksHost
103+
if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") {
104+
exchangeURL = "https://" + exchangeURL
105+
}
106+
if !strings.HasSuffix(exchangeURL, "/") {
107+
exchangeURL += "/"
108+
}
109+
exchangeURL += "oidc/v1/token"
110+
111+
// Prepare form data for token exchange
112+
data := url.Values{}
113+
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
114+
data.Set("scope", "sql")
115+
data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
116+
data.Set("subject_token", subjectToken)
117+
118+
if p.returnOriginalTokenIfAuthenticated {
119+
data.Set("return_original_token_if_authenticated", "true")
120+
}
121+
122+
// Add client_id for SP-wide federation
123+
if p.clientID != "" {
124+
data.Set("client_id", p.clientID)
125+
}
126+
127+
// Create request
128+
req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode()))
129+
if err != nil {
130+
return nil, fmt.Errorf("failed to create request: %w", err)
131+
}
132+
133+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
134+
req.Header.Set("Accept", "*/*")
135+
136+
// Make request
137+
resp, err := p.httpClient.Do(req)
138+
if err != nil {
139+
return nil, fmt.Errorf("request failed: %w", err)
140+
}
141+
defer resp.Body.Close()
142+
143+
body, err := io.ReadAll(resp.Body)
144+
if err != nil {
145+
return nil, fmt.Errorf("failed to read response: %w", err)
146+
}
147+
148+
if resp.StatusCode != http.StatusOK {
149+
return nil, fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body))
150+
}
151+
152+
// Parse response
153+
var tokenResp struct {
154+
AccessToken string `json:"access_token"`
155+
TokenType string `json:"token_type"`
156+
ExpiresIn int `json:"expires_in"`
157+
Scope string `json:"scope"`
158+
}
159+
160+
if err := json.Unmarshal(body, &tokenResp); err != nil {
161+
return nil, fmt.Errorf("failed to parse response: %w", err)
162+
}
163+
164+
token := &Token{
165+
AccessToken: tokenResp.AccessToken,
166+
TokenType: tokenResp.TokenType,
167+
Scopes: strings.Fields(tokenResp.Scope),
168+
}
169+
170+
if tokenResp.ExpiresIn > 0 {
171+
token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
172+
}
173+
174+
return token, nil
175+
}
176+
177+
// isSameHost compares two URLs to see if they have the same host
178+
func (p *FederationProvider) isSameHost(url1, url2 string) bool {
179+
// Add scheme to url2 if it doesn't have one (databricksHost may not have scheme)
180+
parsedURL2 := url2
181+
if !strings.HasPrefix(url2, "http://") && !strings.HasPrefix(url2, "https://") {
182+
parsedURL2 = "https://" + url2
183+
}
184+
185+
u1, err1 := url.Parse(url1)
186+
u2, err2 := url.Parse(parsedURL2)
187+
188+
if err1 != nil || err2 != nil {
189+
return false
190+
}
191+
192+
// Use Hostname() instead of Host to ignore port differences
193+
// This handles cases like "host.com:443" == "host.com" for HTTPS
194+
return u1.Hostname() == u2.Hostname()
195+
}
196+
197+
// Name returns the provider name
198+
func (p *FederationProvider) Name() string {
199+
baseName := p.baseProvider.Name()
200+
if p.clientID != "" {
201+
return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability
202+
}
203+
return fmt.Sprintf("federation[%s]", baseName)
204+
}

0 commit comments

Comments
 (0)