diff --git a/pkg/storage/azure/azure.go b/pkg/storage/azure/azure.go index 1a732fd05..1990a242c 100644 --- a/pkg/storage/azure/azure.go +++ b/pkg/storage/azure/azure.go @@ -57,6 +57,10 @@ const ( azureCredentialsKey = "AzureCredentials" ) +// globalAzureCredentials caches User Assigned Managed Identity (UAMI) credentials across driver instances so that +// reconciles do not recreate the credential repeatedly. +var globalAzureCredentials sync.Map + // storageAccountInvalidCharRe is a regular expression for characters that // cannot be used in Azure storage accounts names (i.e. that are not // numbers nor lower-case letters) and that are not upper-case letters. If @@ -318,10 +322,6 @@ type driver struct { // policies is for new Azure Client Pipeline execution. // Added as a member to the struct to allow injection for testing. policies []policy.Policy - - // azureCredentials keeps track if we have already loaded an Azure - // credentials token when using UAMI for managed Azure on HCP. - azureCredentials sync.Map } // NewDriver creates a new storage driver for Azure Blob Storage. @@ -334,7 +334,7 @@ func NewDriver(ctx context.Context, c *imageregistryv1.ImageRegistryConfigStorag } func (d *driver) newAzClient(cfg *Azure, environment autorestazure.Environment, tagset map[string]*string) (*azureclient.Client, error) { - client, err := azureclient.New(&azureclient.Options{ + clientOptions := &azureclient.Options{ Environment: environment, TenantID: cfg.TenantID, ClientID: cfg.ClientID, @@ -343,10 +343,19 @@ func (d *driver) newAzClient(cfg *Azure, environment autorestazure.Environment, SubscriptionID: cfg.SubscriptionID, TagSet: tagset, Policies: d.policies, - }) + } + + if cred, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil { + return nil, err + } else if ok { + clientOptions.Creds = cred + } + + client, err := azureclient.New(clientOptions) if err != nil { return nil, err } + return client, nil } @@ -381,25 +390,10 @@ func (d *driver) storageAccountsClient(cfg *Azure, environment autorestazure.Env // UserAssignedIdentityCredentials is specifically for managed Azure HCP userAssignedIdentityCredentialsFilePath := os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH") if userAssignedIdentityCredentialsFilePath != "" { - var ok bool - - // We need to only store the Azure credentials once and reuse them after that. - storedCreds, found := d.azureCredentials.Load(userAssignedIdentityCredentialsFilePath) - if !found { - klog.V(2).Info("Using UserAssignedIdentityCredentials for Azure authentication for managed Azure HCP") - clientOptions := azcore.ClientOptions{ - Cloud: cloudConfig, - } - cred, err = dataplane.NewUserAssignedIdentityCredential(context.Background(), userAssignedIdentityCredentialsFilePath, dataplane.WithClientOpts(clientOptions)) - if err != nil { - return storage.AccountsClient{}, err - } - d.azureCredentials.Store(azureCredentialsKey, cred) - } else { - cred, ok = storedCreds.(azcore.TokenCredential) - if !ok { - return storage.AccountsClient{}, fmt.Errorf("expected %T to be a TokenCredential", storedCreds) - } + if c, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil { + return storage.AccountsClient{}, err + } else if ok { + cred = c } } else if strings.TrimSpace(cfg.ClientSecret) == "" { options := azidentity.WorkloadIdentityCredentialOptions{ @@ -1237,14 +1231,30 @@ func (d *driver) RemoveStorage(cr *imageregistryv1.Config) (retry bool, err erro } if d.Config.NetworkAccess != nil && d.Config.NetworkAccess.Internal != nil && d.Config.NetworkAccess.Internal.PrivateEndpointName != "" { - azclient, err := azureclient.New(&azureclient.Options{ + clientOptions := &azureclient.Options{ Environment: environment, TenantID: cfg.TenantID, ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, FederatedTokenFile: cfg.FederatedTokenFile, SubscriptionID: cfg.SubscriptionID, - }) + } + + if cred, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil { + util.UpdateCondition( + cr, + defaults.StorageExists, + operatorapiv1.ConditionUnknown, + storageExistsReasonAzureError, + fmt.Sprintf("Unable to get azure client: %s", err), + ) + return false, err + } else if ok { + klog.V(2).Infof("Using cached UAMI credential for RemoveStorage client") + clientOptions.Creds = cred + } + + azclient, err := azureclient.New(clientOptions) if err != nil { util.UpdateCondition( cr, @@ -1320,3 +1330,49 @@ func (d *driver) RemoveStorage(cr *imageregistryv1.Config) (retry bool, err erro func (d *driver) ID() string { return d.Config.Container } + +// ensureUAMICredentials obtains and caches an Azure TokenCredential using a +// User Assigned Managed Identity (UAMI). +// +// If MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH is unset, it returns (nil, false, nil). +// When set, it loads a credential from a process-wide cache or creates one using the +// provided Azure environment endpoints, stores it, and returns it. +// +// The bool result is true when a UAMI credential is available. An error is returned if +// credential creation fails or a cached value has an unexpected type. +// +// ctx controls cancellation of credential creation. env supplies Azure endpoints. +func (d *driver) ensureUAMICredentials(ctx context.Context, env autorestazure.Environment) (azcore.TokenCredential, bool, error) { + if os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH") == "" { + return nil, false, nil + } + if stored, ok := globalAzureCredentials.Load(azureCredentialsKey); ok { + if cred, ok := stored.(azcore.TokenCredential); ok { + klog.V(2).Infof("Loaded UAMI credentials from cache") + return cred, true, nil + } + return nil, false, fmt.Errorf("expected cached credential to be azcore.TokenCredential") + } + cloudConfig := cloud.Configuration{ + ActiveDirectoryAuthorityHost: env.ActiveDirectoryEndpoint, + Services: map[cloud.ServiceName]cloud.ServiceConfiguration{ + cloud.ResourceManager: { + Audience: env.TokenAudience, + Endpoint: env.ResourceManagerEndpoint, + }, + }, + } + cred, err := dataplane.NewUserAssignedIdentityCredential( + ctx, + os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH"), + dataplane.WithClientOpts(azcore.ClientOptions{Cloud: cloudConfig}), + ) + if err != nil { + return nil, false, err + } + if actual, loaded := globalAzureCredentials.LoadOrStore(azureCredentialsKey, cred); loaded { + return actual.(azcore.TokenCredential), true, nil + } + klog.V(2).Infof("Storing UAMI credentials to global cache") + return cred, true, nil +} diff --git a/pkg/storage/azure/azure_test.go b/pkg/storage/azure/azure_test.go index de38fe7e3..6c32c55aa 100644 --- a/pkg/storage/azure/azure_test.go +++ b/pkg/storage/azure/azure_test.go @@ -12,10 +12,13 @@ import ( "regexp" "strings" "testing" + "time" "github.com/Azure/azure-pipeline-go/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/go-autorest/autorest" + autorestazure "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/mocks" "github.com/Azure/go-autorest/autorest/to" "github.com/google/go-cmp/cmp" @@ -1416,3 +1419,145 @@ func Test_storageManagementStateNonAzureStackHub(t *testing.T) { }) } } + +// fakeTokenCredential implements azcore.TokenCredential for testing +type fakeTokenCredential struct { + id string +} + +func (f *fakeTokenCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + return azcore.AccessToken{ + Token: "fake-token-" + f.id, + ExpiresOn: time.Now().Add(time.Hour), + }, nil +} + +// resetGlobalAzureCredentials clears the global cache between tests +func resetGlobalAzureCredentials() { + globalAzureCredentials.Range(func(key, value any) bool { + globalAzureCredentials.Delete(key) + return true + }) +} + +func TestEnsureUAMICredentials(t *testing.T) { + for _, tt := range []struct { + name string + envValue string + cacheSetup func() + expectedCred azcore.TokenCredential + expectedOk bool + expectedErr string + }{ + { + name: "environment variable not set", + envValue: "", + expectedCred: nil, + expectedOk: false, + expectedErr: "", + }, + { + name: "credential loaded from cache", + envValue: "/path/to/creds.json", + cacheSetup: func() { + resetGlobalAzureCredentials() + fakeCred := &fakeTokenCredential{id: "cached"} + globalAzureCredentials.Store(azureCredentialsKey, fakeCred) + }, + expectedCred: &fakeTokenCredential{id: "cached"}, + expectedOk: true, + expectedErr: "", + }, + { + name: "invalid cached credential type", + envValue: "/path/to/creds.json", + cacheSetup: func() { + resetGlobalAzureCredentials() + // Store wrong type in cache + globalAzureCredentials.Store(azureCredentialsKey, "not-a-credential") + }, + expectedCred: nil, + expectedOk: false, + expectedErr: "expected cached credential to be azcore.TokenCredential", + }, + } { + t.Run(tt.name, func(t *testing.T) { + // Set up environment using t.Setenv (Go 1.17+) + if tt.envValue != "" { + t.Setenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH", tt.envValue) + } + + // Set up cache + if tt.cacheSetup != nil { + tt.cacheSetup() + } + + // Create driver and call the function + d := &driver{} + env := autorestazure.PublicCloud + cred, ok, err := d.ensureUAMICredentials(context.Background(), env) + + // Verify error + if tt.expectedErr != "" { + if err == nil || err.Error() != tt.expectedErr { + t.Errorf("expected error %q, got %v", tt.expectedErr, err) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // Verify ok result + if ok != tt.expectedOk { + t.Errorf("expected ok=%v, got %v", tt.expectedOk, ok) + } + + // Verify credential result + if tt.expectedCred != nil { + if cred == nil { + t.Errorf("expected credential, got nil") + } else { + // Check that we got the right credential by comparing the token + expectedToken, _ := tt.expectedCred.GetToken(context.Background(), policy.TokenRequestOptions{}) + actualToken, _ := cred.GetToken(context.Background(), policy.TokenRequestOptions{}) + if expectedToken.Token != actualToken.Token { + t.Errorf("expected credential with token %q, got %q", expectedToken.Token, actualToken.Token) + } + } + } else if cred != nil { + t.Errorf("expected nil credential, got %v", cred) + } + }) + } +} + +func TestEnsureUAMICredentials_CacheUsage(t *testing.T) { + // Set up environment using t.Setenv + t.Setenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH", "/path/to/creds.json") + + // Reset cache and add a credential + resetGlobalAzureCredentials() + fakeCred := &fakeTokenCredential{id: "test"} + globalAzureCredentials.Store(azureCredentialsKey, fakeCred) + + d := &driver{} + env := autorestazure.PublicCloud + + // First call should load from cache + cred1, ok1, err1 := d.ensureUAMICredentials(context.Background(), env) + if err1 != nil || !ok1 || cred1 == nil { + t.Fatalf("first call failed: err=%v ok=%v cred=%v", err1, ok1, cred1) + } + + // Second call should also load from cache + cred2, ok2, err2 := d.ensureUAMICredentials(context.Background(), env) + if err2 != nil || !ok2 || cred2 == nil { + t.Fatalf("second call failed: err=%v ok=%v cred=%v", err2, ok2, cred2) + } + + // Verify same credential instance returned + token1, _ := cred1.GetToken(context.Background(), policy.TokenRequestOptions{}) + token2, _ := cred2.GetToken(context.Background(), policy.TokenRequestOptions{}) + if token1.Token != token2.Token { + t.Errorf("expected same credential from cache, got different tokens: %q vs %q", token1.Token, token2.Token) + } +} diff --git a/pkg/storage/azure/azureclient/azureclient.go b/pkg/storage/azure/azureclient/azureclient.go index 9e8457b88..b64783fef 100644 --- a/pkg/storage/azure/azureclient/azureclient.go +++ b/pkg/storage/azure/azureclient/azureclient.go @@ -111,7 +111,7 @@ func (c *Client) getCreds(ctx context.Context) (azcore.TokenCredential, error) { var ok bool // We need to only store the Azure credentials once and reuse them after that. - storedCreds, found := c.azureCredentials.Load(userAssignedIdentityCredentialsFilePath) + storedCreds, found := c.azureCredentials.Load(azureCredentialsKey) if !found { klog.V(2).Info("Using UserAssignedIdentityCredentials for Azure authentication for managed Azure HCP") clientOptions := azcore.ClientOptions{