diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 549d68ab..3a151942 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -20,9 +20,11 @@ import ( "fmt" "os" "strings" + "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" + internalcache "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" @@ -259,6 +261,17 @@ type clientOptions struct { httpClient ops.HTTPClient } +// EnhancedClient provides thread-safe token caching with auto-renewal +type EnhancedClient struct { + Client + tokenCache *internalcache.TokenCache +} + +// EnhancedClientOptions contains options for creating an enhanced client +type EnhancedClientOptions struct { + RenewalBuffer time.Duration // How long before expiry to renew tokens +} + // Option is an optional argument to New(). type Option func(o *clientOptions) @@ -357,6 +370,37 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client return Client{base: base, cred: internalCred}, nil } +// NewEnhancedClient creates a new confidential client with enhanced token caching +func NewEnhancedClient(authority, clientID string, cred Credential, options ...Option) (EnhancedClient, error) { + client, err := New(authority, clientID, cred, options...) + if err != nil { + return EnhancedClient{}, err + } + + // Create token cache with default 2-minute renewal buffer + tokenCache := internalcache.NewTokenCache(2 * time.Minute) + + return EnhancedClient{ + Client: client, + tokenCache: tokenCache, + }, nil +} + +// NewEnhancedClientWithOptions creates a new enhanced client with custom options +func NewEnhancedClientWithOptions(authority, clientID string, cred Credential, enhancedOpts EnhancedClientOptions, options ...Option) (EnhancedClient, error) { + client, err := New(authority, clientID, cred, options...) + if err != nil { + return EnhancedClient{}, err + } + + tokenCache := internalcache.NewTokenCache(enhancedOpts.RenewalBuffer) + + return EnhancedClient{ + Client: client, + tokenCache: tokenCache, + }, nil +} + // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string @@ -743,6 +787,99 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string, return cca.base.AuthResultFromToken(ctx, authParams, token) } +// AcquireTokenByCredentialWithCaching acquires a token using client credentials with enhanced caching +// This method automatically handles token reuse and renewal +func (eca EnhancedClient) AcquireTokenByCredentialWithCaching(ctx context.Context, scopes []string, opts ...AcquireByCredentialOption) (AuthResult, error) { + // Extract tenant ID from options + o := acquireTokenByCredentialOptions{} + if err := options.ApplyOptions(&o, opts); err != nil { + return AuthResult{}, err + } + + tenantID := o.tenantID + if tenantID == "" { + // Use default tenant from client + tenantID = eca.base.AuthParams.AuthorityInfo.Tenant + } + + // Check if we have a valid cached token + if cachedToken := eca.tokenCache.GetToken(scopes, tenantID); cachedToken != "" { + // Return cached token result + return AuthResult{ + AccessToken: cachedToken, + ExpiresOn: eca.getCachedTokenExpiry(scopes, tenantID), + GrantedScopes: scopes, + Metadata: base.AuthResultMetadata{ + TokenSource: base.TokenSourceCache, + }, + }, nil + } + + // No valid cached token, acquire new one + result, err := eca.AcquireTokenByCredential(ctx, scopes, opts...) + if err != nil { + return AuthResult{}, err + } + + // Cache the new token + eca.tokenCache.SetToken(scopes, tenantID, result.AccessToken, result.ExpiresOn) + + return result, nil +} + +// ForceRefreshToken clears cache and acquires a new token +func (eca EnhancedClient) ForceRefreshToken(ctx context.Context, scopes []string, opts ...AcquireByCredentialOption) (AuthResult, error) { + // Extract tenant ID from options + o := acquireTokenByCredentialOptions{} + if err := options.ApplyOptions(&o, opts); err != nil { + return AuthResult{}, err + } + + tenantID := o.tenantID + if tenantID == "" { + tenantID = eca.base.AuthParams.AuthorityInfo.Tenant + } + + // Clear cached token + eca.tokenCache.ClearToken(scopes, tenantID) + + // Acquire new token + return eca.AcquireTokenByCredentialWithCaching(ctx, scopes, opts...) +} + +// ClearTokenCache removes all cached tokens +func (eca EnhancedClient) ClearTokenCache() { + eca.tokenCache.ClearAll() +} + +// IsTokenCached checks if a valid token is cached +func (eca EnhancedClient) IsTokenCached(scopes []string, tenantID string) bool { + if tenantID == "" { + tenantID = eca.base.AuthParams.AuthorityInfo.Tenant + } + return eca.tokenCache.IsTokenValid(scopes, tenantID) +} + +// GetCacheStats returns statistics about the token cache +func (eca EnhancedClient) GetCacheStats() map[string]interface{} { + return eca.tokenCache.GetCacheStats() +} + +// getCachedTokenExpiry returns the expiry time of a cached token +func (eca EnhancedClient) getCachedTokenExpiry(scopes []string, tenantID string) time.Time { + if tenantID == "" { + tenantID = eca.base.AuthParams.AuthorityInfo.Tenant + } + + // Get the actual cached token data + cachedData := eca.tokenCache.GetCachedTokenData(scopes, tenantID) + if cachedData == nil { + return time.Time{} // Zero time if no cached data + } + + return cachedData.ExpiresAt +} + // acquireTokenOnBehalfOfOptions contains optional configuration for AcquireTokenOnBehalfOf type acquireTokenOnBehalfOfOptions struct { claims, tenantID string diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 165a662f..cfc478a7 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -173,6 +173,116 @@ func TestAcquireTokenByCredential(t *testing.T) { } } +func TestEnhancedClientTokenCaching(t *testing.T) { + cred, err := NewCredFromSecret("test_secret") + if err != nil { + t.Fatal(err) + } + + client, err := NewEnhancedClient("https://login.microsoftonline.com/common", "test_client_id", cred) + if err != nil { + t.Fatal(err) + } + + scopes := []string{"https://graph.microsoft.com/.default"} + ctx := context.Background() + + // First call should acquire new token + result1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + t.Fatalf("First token acquisition failed: %v", err) + } + + // Second call should return cached token + result2, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + t.Fatalf("Second token acquisition failed: %v", err) + } + + // Tokens should be the same (cached) + if result1.AccessToken != result2.AccessToken { + t.Error("Expected cached token, but got different token") + } + + // Check cache stats + stats := client.GetCacheStats() + if stats["valid_tokens"].(int) != 1 { + t.Errorf("Expected 1 valid token in cache, got %v", stats["valid_tokens"]) + } +} + +func TestEnhancedClientForceRefresh(t *testing.T) { + cred, err := NewCredFromSecret("test_secret") + if err != nil { + t.Fatal(err) + } + + client, err := NewEnhancedClient("https://login.microsoftonline.com/common", "test_client_id", cred) + if err != nil { + t.Fatal(err) + } + + scopes := []string{"https://graph.microsoft.com/.default"} + ctx := context.Background() + + // Acquire initial token + result1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + t.Fatal(err) + } + + // Force refresh should get a new token + result2, err := client.ForceRefreshToken(ctx, scopes) + if err != nil { + t.Fatal(err) + } + + // Tokens should be different + if result1.AccessToken == result2.AccessToken { + t.Error("Expected different tokens after force refresh") + } +} + +func TestEnhancedClientCacheExpiry(t *testing.T) { + // Create client with very short renewal buffer for testing + cred, err := NewCredFromSecret("test_secret") + if err != nil { + t.Fatal(err) + } + + opts := EnhancedClientOptions{ + RenewalBuffer: 1 * time.Millisecond, // Very short buffer for testing + } + + client, err := NewEnhancedClientWithOptions("https://login.microsoftonline.com/common", "test_client_id", cred, opts) + if err != nil { + t.Fatal(err) + } + + scopes := []string{"https://graph.microsoft.com/.default"} + ctx := context.Background() + + // Acquire token + result1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + t.Fatal(err) + } + + // Wait for token to be considered expired + time.Sleep(10 * time.Millisecond) + + // Next call should acquire new token due to expiry + result2, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + t.Fatal(err) + } + + // Tokens should be different due to expiry + if result1.AccessToken == result2.AccessToken { + t.Error("Expected different tokens after expiry") + } +} + func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { diff --git a/apps/confidential/examples_test.go b/apps/confidential/examples_test.go index 27d9f495..a87a9b87 100644 --- a/apps/confidential/examples_test.go +++ b/apps/confidential/examples_test.go @@ -59,3 +59,53 @@ func ExampleNewCredFromCert_pem() { } fmt.Println(cred) // Simply here so cred is used, otherwise won't compile. } + +// This example demonstrates the enhanced client with automatic token caching and renewal +func ExampleEnhancedClient() { + // Create credential + cred, err := confidential.NewCredFromSecret("client_secret") + if err != nil { + log.Fatal(err) + } + + // Create enhanced client with automatic token caching + client, err := confidential.NewEnhancedClient( + "https://login.microsoftonline.com/your_tenant", + "client_id", + cred, + ) + if err != nil { + log.Fatal(err) + } + + scopes := []string{"https://graph.microsoft.com/.default"} + ctx := context.Background() + + // First call acquires and caches token + token1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + log.Fatal(err) + } + fmt.Printf("First token: %s\n", token1.AccessToken) + + // Second call returns cached token (no network request) + token2, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Second token (cached): %s\n", token2.AccessToken) + + // Check if tokens are the same (cached) + fmt.Printf("Tokens are same: %t\n", token1.AccessToken == token2.AccessToken) + + // Force refresh to get new token + token3, err := client.ForceRefreshToken(ctx, scopes) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Force refreshed token: %s\n", token3.AccessToken) + + // Get cache statistics + stats := client.GetCacheStats() + fmt.Printf("Cache stats: %+v\n", stats) +} diff --git a/apps/internal/cache/token_cache.go b/apps/internal/cache/token_cache.go new file mode 100644 index 00000000..45f33e68 --- /dev/null +++ b/apps/internal/cache/token_cache.go @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package cache + +import ( + "fmt" + "strings" + "sync" + "time" +) + +// TokenCache provides thread-safe token caching with auto-renewal +type TokenCache struct { + mu sync.RWMutex + tokens map[string]*CachedToken + renewalBuffer time.Duration // How long before expiry to renew (default: 2 minutes) +} + +// CachedToken represents a cached token with metadata +type CachedToken struct { + Token string + ExpiresAt time.Time + Scopes []string + TenantID string + CreatedAt time.Time +} + +// NewTokenCache creates a new thread-safe token cache +func NewTokenCache(renewalBuffer time.Duration) *TokenCache { + if renewalBuffer <= 0 { + renewalBuffer = 2 * time.Minute // Default 2-minute buffer + } + return &TokenCache{ + tokens: make(map[string]*CachedToken), + renewalBuffer: renewalBuffer, + } +} + +// GetToken retrieves a valid token from cache, returns empty string if not found or expired +func (tc *TokenCache) GetToken(scopes []string, tenantID string) string { + key := tc.generateKey(scopes, tenantID) + + tc.mu.RLock() + cached, exists := tc.tokens[key] + tc.mu.RUnlock() + + if !exists { + return "" + } + + // Check if token is still valid (with renewal buffer) + if time.Now().Add(tc.renewalBuffer).After(cached.ExpiresAt) { + // Token is expired or about to expire, remove it + tc.mu.Lock() + delete(tc.tokens, key) + tc.mu.Unlock() + return "" + } + + return cached.Token +} + +// GetCachedTokenData returns the full cached token data +func (tc *TokenCache) GetCachedTokenData(scopes []string, tenantID string) *CachedToken { + tc.mu.RLock() + defer tc.mu.RUnlock() + + key := tc.generateKey(scopes, tenantID) + if data, exists := tc.tokens[key]; exists && data.ExpiresAt.After(time.Now()) { + return data + } + return nil +} + +// SetToken stores a token in the cache +func (tc *TokenCache) SetToken(scopes []string, tenantID, token string, expiresAt time.Time) { + key := tc.generateKey(scopes, tenantID) + + tc.mu.Lock() + tc.tokens[key] = &CachedToken{ + Token: token, + ExpiresAt: expiresAt, + Scopes: scopes, + TenantID: tenantID, + CreatedAt: time.Now(), + } + tc.mu.Unlock() +} + +// ClearToken removes a specific token from cache +func (tc *TokenCache) ClearToken(scopes []string, tenantID string) { + key := tc.generateKey(scopes, tenantID) + + tc.mu.Lock() + delete(tc.tokens, key) + tc.mu.Unlock() +} + +// ClearAll removes all tokens from cache +func (tc *TokenCache) ClearAll() { + tc.mu.Lock() + tc.tokens = make(map[string]*CachedToken) + tc.mu.Unlock() +} + +// IsTokenValid checks if a token exists and is valid +func (tc *TokenCache) IsTokenValid(scopes []string, tenantID string) bool { + key := tc.generateKey(scopes, tenantID) + + tc.mu.RLock() + cached, exists := tc.tokens[key] + tc.mu.RUnlock() + + if !exists { + return false + } + + return time.Now().Add(tc.renewalBuffer).Before(cached.ExpiresAt) +} + +// generateKey creates a unique key for the token cache +func (tc *TokenCache) generateKey(scopes []string, tenantID string) string { + // Create a deterministic key from scopes and tenant + scopeStr := strings.Join(scopes, ",") + return fmt.Sprintf("%s|%s", tenantID, scopeStr) +} + +// GetCacheStats returns statistics about the cache +func (tc *TokenCache) GetCacheStats() map[string]interface{} { + tc.mu.RLock() + defer tc.mu.RUnlock() + + now := time.Now() + validCount := 0 + expiredCount := 0 + + for _, token := range tc.tokens { + if now.Add(tc.renewalBuffer).Before(token.ExpiresAt) { + validCount++ + } else { + expiredCount++ + } + } + + return map[string]interface{}{ + "total_tokens": len(tc.tokens), + "valid_tokens": validCount, + "expired_tokens": expiredCount, + "renewal_buffer": tc.renewalBuffer.String(), + } +}