Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"fmt"
"os"
"strings"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
internalcache "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
Expand Down Expand Up @@ -259,6 +261,17 @@ type clientOptions struct {
httpClient ops.HTTPClient
}

// EnhancedClient provides thread-safe token caching with auto-renewal
type EnhancedClient struct {
Client
tokenCache *internalcache.TokenCache
}

// EnhancedClientOptions contains options for creating an enhanced client
type EnhancedClientOptions struct {
RenewalBuffer time.Duration // How long before expiry to renew tokens
}

// Option is an optional argument to New().
type Option func(o *clientOptions)

Expand Down Expand Up @@ -357,6 +370,37 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client
return Client{base: base, cred: internalCred}, nil
}

// NewEnhancedClient creates a new confidential client with enhanced token caching
func NewEnhancedClient(authority, clientID string, cred Credential, options ...Option) (EnhancedClient, error) {
client, err := New(authority, clientID, cred, options...)
if err != nil {
return EnhancedClient{}, err
}

// Create token cache with default 2-minute renewal buffer
tokenCache := internalcache.NewTokenCache(2 * time.Minute)

return EnhancedClient{
Client: client,
tokenCache: tokenCache,
}, nil
}

// NewEnhancedClientWithOptions creates a new enhanced client with custom options
func NewEnhancedClientWithOptions(authority, clientID string, cred Credential, enhancedOpts EnhancedClientOptions, options ...Option) (EnhancedClient, error) {
client, err := New(authority, clientID, cred, options...)
if err != nil {
return EnhancedClient{}, err
}

tokenCache := internalcache.NewTokenCache(enhancedOpts.RenewalBuffer)

return EnhancedClient{
Client: client,
tokenCache: tokenCache,
}, nil
}

// authCodeURLOptions contains options for AuthCodeURL
type authCodeURLOptions struct {
claims, loginHint, tenantID, domainHint string
Expand Down Expand Up @@ -743,6 +787,99 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string,
return cca.base.AuthResultFromToken(ctx, authParams, token)
}

// AcquireTokenByCredentialWithCaching acquires a token using client credentials with enhanced caching
// This method automatically handles token reuse and renewal
func (eca EnhancedClient) AcquireTokenByCredentialWithCaching(ctx context.Context, scopes []string, opts ...AcquireByCredentialOption) (AuthResult, error) {
// Extract tenant ID from options
o := acquireTokenByCredentialOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
return AuthResult{}, err
}

tenantID := o.tenantID
if tenantID == "" {
// Use default tenant from client
tenantID = eca.base.AuthParams.AuthorityInfo.Tenant
}

// Check if we have a valid cached token
if cachedToken := eca.tokenCache.GetToken(scopes, tenantID); cachedToken != "" {
// Return cached token result
return AuthResult{
AccessToken: cachedToken,
ExpiresOn: eca.getCachedTokenExpiry(scopes, tenantID),
GrantedScopes: scopes,
Metadata: base.AuthResultMetadata{
TokenSource: base.TokenSourceCache,
},
}, nil
}

// No valid cached token, acquire new one
result, err := eca.AcquireTokenByCredential(ctx, scopes, opts...)
if err != nil {
return AuthResult{}, err
}

// Cache the new token
eca.tokenCache.SetToken(scopes, tenantID, result.AccessToken, result.ExpiresOn)

return result, nil
}

// ForceRefreshToken clears cache and acquires a new token
func (eca EnhancedClient) ForceRefreshToken(ctx context.Context, scopes []string, opts ...AcquireByCredentialOption) (AuthResult, error) {
// Extract tenant ID from options
o := acquireTokenByCredentialOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
return AuthResult{}, err
}

tenantID := o.tenantID
if tenantID == "" {
tenantID = eca.base.AuthParams.AuthorityInfo.Tenant
}

// Clear cached token
eca.tokenCache.ClearToken(scopes, tenantID)

// Acquire new token
return eca.AcquireTokenByCredentialWithCaching(ctx, scopes, opts...)
}

// ClearTokenCache removes all cached tokens
func (eca EnhancedClient) ClearTokenCache() {
eca.tokenCache.ClearAll()
}

// IsTokenCached checks if a valid token is cached
func (eca EnhancedClient) IsTokenCached(scopes []string, tenantID string) bool {
if tenantID == "" {
tenantID = eca.base.AuthParams.AuthorityInfo.Tenant
}
return eca.tokenCache.IsTokenValid(scopes, tenantID)
}

// GetCacheStats returns statistics about the token cache
func (eca EnhancedClient) GetCacheStats() map[string]interface{} {
return eca.tokenCache.GetCacheStats()
}

// getCachedTokenExpiry returns the expiry time of a cached token
func (eca EnhancedClient) getCachedTokenExpiry(scopes []string, tenantID string) time.Time {
if tenantID == "" {
tenantID = eca.base.AuthParams.AuthorityInfo.Tenant
}

// Get the actual cached token data
cachedData := eca.tokenCache.GetCachedTokenData(scopes, tenantID)
if cachedData == nil {
return time.Time{} // Zero time if no cached data
}

return cachedData.ExpiresAt
}

// acquireTokenOnBehalfOfOptions contains optional configuration for AcquireTokenOnBehalfOf
type acquireTokenOnBehalfOfOptions struct {
claims, tenantID string
Expand Down
110 changes: 110 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,116 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
}

func TestEnhancedClientTokenCaching(t *testing.T) {
cred, err := NewCredFromSecret("test_secret")
if err != nil {
t.Fatal(err)
}

client, err := NewEnhancedClient("https://login.microsoftonline.com/common", "test_client_id", cred)
if err != nil {
t.Fatal(err)
}

scopes := []string{"https://graph.microsoft.com/.default"}
ctx := context.Background()

// First call should acquire new token
result1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
t.Fatalf("First token acquisition failed: %v", err)
}

// Second call should return cached token
result2, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
t.Fatalf("Second token acquisition failed: %v", err)
}

// Tokens should be the same (cached)
if result1.AccessToken != result2.AccessToken {
t.Error("Expected cached token, but got different token")
}

// Check cache stats
stats := client.GetCacheStats()
if stats["valid_tokens"].(int) != 1 {
t.Errorf("Expected 1 valid token in cache, got %v", stats["valid_tokens"])
}
}

func TestEnhancedClientForceRefresh(t *testing.T) {
cred, err := NewCredFromSecret("test_secret")
if err != nil {
t.Fatal(err)
}

client, err := NewEnhancedClient("https://login.microsoftonline.com/common", "test_client_id", cred)
if err != nil {
t.Fatal(err)
}

scopes := []string{"https://graph.microsoft.com/.default"}
ctx := context.Background()

// Acquire initial token
result1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
t.Fatal(err)
}

// Force refresh should get a new token
result2, err := client.ForceRefreshToken(ctx, scopes)
if err != nil {
t.Fatal(err)
}

// Tokens should be different
if result1.AccessToken == result2.AccessToken {
t.Error("Expected different tokens after force refresh")
}
}

func TestEnhancedClientCacheExpiry(t *testing.T) {
// Create client with very short renewal buffer for testing
cred, err := NewCredFromSecret("test_secret")
if err != nil {
t.Fatal(err)
}

opts := EnhancedClientOptions{
RenewalBuffer: 1 * time.Millisecond, // Very short buffer for testing
}

client, err := NewEnhancedClientWithOptions("https://login.microsoftonline.com/common", "test_client_id", cred, opts)
if err != nil {
t.Fatal(err)
}

scopes := []string{"https://graph.microsoft.com/.default"}
ctx := context.Background()

// Acquire token
result1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
t.Fatal(err)
}

// Wait for token to be considered expired
time.Sleep(10 * time.Millisecond)

// Next call should acquire new token due to expiry
result2, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
t.Fatal(err)
}

// Tokens should be different due to expiry
if result1.AccessToken == result2.AccessToken {
t.Error("Expected different tokens after expiry")
}
}

func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
Expand Down
50 changes: 50 additions & 0 deletions apps/confidential/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,53 @@ func ExampleNewCredFromCert_pem() {
}
fmt.Println(cred) // Simply here so cred is used, otherwise won't compile.
}

// This example demonstrates the enhanced client with automatic token caching and renewal
func ExampleEnhancedClient() {
// Create credential
cred, err := confidential.NewCredFromSecret("client_secret")
if err != nil {
log.Fatal(err)
}

// Create enhanced client with automatic token caching
client, err := confidential.NewEnhancedClient(
"https://login.microsoftonline.com/your_tenant",
"client_id",
cred,
)
if err != nil {
log.Fatal(err)
}

scopes := []string{"https://graph.microsoft.com/.default"}
ctx := context.Background()

// First call acquires and caches token
token1, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
log.Fatal(err)
}
fmt.Printf("First token: %s\n", token1.AccessToken)

// Second call returns cached token (no network request)
token2, err := client.AcquireTokenByCredentialWithCaching(ctx, scopes)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Second token (cached): %s\n", token2.AccessToken)

// Check if tokens are the same (cached)
fmt.Printf("Tokens are same: %t\n", token1.AccessToken == token2.AccessToken)

// Force refresh to get new token
token3, err := client.ForceRefreshToken(ctx, scopes)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Force refreshed token: %s\n", token3.AccessToken)

// Get cache statistics
stats := client.GetCacheStats()
fmt.Printf("Cache stats: %+v\n", stats)
}
Loading