Skip to content

OCPBUGS-59734: fix(azure): resolve credential caching issues around UAMI support #1238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 83 additions & 27 deletions pkg/storage/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ const (
azureCredentialsKey = "AzureCredentials"
)

// globalAzureCredentials caches User Assigned Managed Identity (UAMI) credentials across driver instances so that
// reconciles do not recreate the credential repeatedly.
var globalAzureCredentials sync.Map

// storageAccountInvalidCharRe is a regular expression for characters that
// cannot be used in Azure storage accounts names (i.e. that are not
// numbers nor lower-case letters) and that are not upper-case letters. If
Expand Down Expand Up @@ -318,10 +322,6 @@ type driver struct {
// policies is for new Azure Client Pipeline execution.
// Added as a member to the struct to allow injection for testing.
policies []policy.Policy

// azureCredentials keeps track if we have already loaded an Azure
// credentials token when using UAMI for managed Azure on HCP.
azureCredentials sync.Map
}

// NewDriver creates a new storage driver for Azure Blob Storage.
Expand All @@ -334,7 +334,7 @@ func NewDriver(ctx context.Context, c *imageregistryv1.ImageRegistryConfigStorag
}

func (d *driver) newAzClient(cfg *Azure, environment autorestazure.Environment, tagset map[string]*string) (*azureclient.Client, error) {
client, err := azureclient.New(&azureclient.Options{
clientOptions := &azureclient.Options{
Environment: environment,
TenantID: cfg.TenantID,
ClientID: cfg.ClientID,
Expand All @@ -343,10 +343,19 @@ func (d *driver) newAzClient(cfg *Azure, environment autorestazure.Environment,
SubscriptionID: cfg.SubscriptionID,
TagSet: tagset,
Policies: d.policies,
})
}

if cred, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil {
return nil, err
} else if ok {
clientOptions.Creds = cred
}

client, err := azureclient.New(clientOptions)
if err != nil {
return nil, err
}

return client, nil
}

Expand Down Expand Up @@ -381,25 +390,10 @@ func (d *driver) storageAccountsClient(cfg *Azure, environment autorestazure.Env
// UserAssignedIdentityCredentials is specifically for managed Azure HCP
userAssignedIdentityCredentialsFilePath := os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH")
if userAssignedIdentityCredentialsFilePath != "" {
var ok bool

// We need to only store the Azure credentials once and reuse them after that.
storedCreds, found := d.azureCredentials.Load(userAssignedIdentityCredentialsFilePath)
if !found {
klog.V(2).Info("Using UserAssignedIdentityCredentials for Azure authentication for managed Azure HCP")
clientOptions := azcore.ClientOptions{
Cloud: cloudConfig,
}
cred, err = dataplane.NewUserAssignedIdentityCredential(context.Background(), userAssignedIdentityCredentialsFilePath, dataplane.WithClientOpts(clientOptions))
if err != nil {
return storage.AccountsClient{}, err
}
d.azureCredentials.Store(azureCredentialsKey, cred)
} else {
cred, ok = storedCreds.(azcore.TokenCredential)
if !ok {
return storage.AccountsClient{}, fmt.Errorf("expected %T to be a TokenCredential", storedCreds)
}
if c, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil {
return storage.AccountsClient{}, err
} else if ok {
cred = c
}
} else if strings.TrimSpace(cfg.ClientSecret) == "" {
options := azidentity.WorkloadIdentityCredentialOptions{
Expand Down Expand Up @@ -1237,14 +1231,30 @@ func (d *driver) RemoveStorage(cr *imageregistryv1.Config) (retry bool, err erro
}

if d.Config.NetworkAccess != nil && d.Config.NetworkAccess.Internal != nil && d.Config.NetworkAccess.Internal.PrivateEndpointName != "" {
azclient, err := azureclient.New(&azureclient.Options{
clientOptions := &azureclient.Options{
Environment: environment,
TenantID: cfg.TenantID,
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
FederatedTokenFile: cfg.FederatedTokenFile,
SubscriptionID: cfg.SubscriptionID,
})
}

if cred, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil {
util.UpdateCondition(
cr,
defaults.StorageExists,
operatorapiv1.ConditionUnknown,
storageExistsReasonAzureError,
fmt.Sprintf("Unable to get azure client: %s", err),
)
return false, err
} else if ok {
klog.V(2).Infof("Using cached UAMI credential for RemoveStorage client")
clientOptions.Creds = cred
}

azclient, err := azureclient.New(clientOptions)
if err != nil {
util.UpdateCondition(
cr,
Expand Down Expand Up @@ -1320,3 +1330,49 @@ func (d *driver) RemoveStorage(cr *imageregistryv1.Config) (retry bool, err erro
func (d *driver) ID() string {
return d.Config.Container
}

// ensureUAMICredentials obtains and caches an Azure TokenCredential using a
// User Assigned Managed Identity (UAMI).
//
// The credentials file is located through the
// MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH environment variable. When that
// variable is unset or empty, UAMI is not in use and the function returns
// (nil, false, nil).
//
// When the variable is set, the credential is looked up in the process-wide
// globalAzureCredentials cache under azureCredentialsKey. On a cache miss a
// new credential is created from the credentials file using the endpoints in
// env, stored in the cache, and returned. If another caller stored a
// credential in the meantime, the already-stored one is returned instead so
// that every caller shares a single instance.
//
// The bool result is true when a UAMI credential is available. An error is
// returned if credential creation fails or if the cached value is not an
// azcore.TokenCredential.
//
// ctx is passed to the credential constructor. env supplies Azure endpoints.
//
// NOTE(review): the credential is cached for the whole process but is built
// with the first caller's ctx. If the dataplane credential ties any
// background work to that ctx (TODO confirm), cancelling it could affect
// later users of the cached credential.
func (d *driver) ensureUAMICredentials(ctx context.Context, env autorestazure.Environment) (azcore.TokenCredential, bool, error) {
	// Read the variable once so the guard and the path handed to the
	// credential constructor can never disagree.
	credentialsPath := os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH")
	if credentialsPath == "" {
		return nil, false, nil
	}

	if stored, found := globalAzureCredentials.Load(azureCredentialsKey); found {
		cred, ok := stored.(azcore.TokenCredential)
		if !ok {
			return nil, false, fmt.Errorf("expected cached credential to be azcore.TokenCredential")
		}
		klog.V(2).Infof("Loaded UAMI credentials from cache")
		return cred, true, nil
	}

	cloudConfig := cloud.Configuration{
		ActiveDirectoryAuthorityHost: env.ActiveDirectoryEndpoint,
		Services: map[cloud.ServiceName]cloud.ServiceConfiguration{
			cloud.ResourceManager: {
				Audience: env.TokenAudience,
				Endpoint: env.ResourceManagerEndpoint,
			},
		},
	}
	created, err := dataplane.NewUserAssignedIdentityCredential(
		ctx,
		credentialsPath,
		dataplane.WithClientOpts(azcore.ClientOptions{Cloud: cloudConfig}),
	)
	if err != nil {
		return nil, false, fmt.Errorf("creating UAMI credential from %q: %w", credentialsPath, err)
	}

	// LoadOrStore keeps whichever credential reached the cache first, so
	// concurrent callers converge on one shared instance.
	actual, loaded := globalAzureCredentials.LoadOrStore(azureCredentialsKey, created)
	if !loaded {
		klog.V(2).Infof("Storing UAMI credentials to global cache")
		return created, true, nil
	}

	// Another caller won the race. Validate its value the same way the Load
	// path above does instead of panicking on an unexpected type.
	cred, ok := actual.(azcore.TokenCredential)
	if !ok {
		return nil, false, fmt.Errorf("expected cached credential to be azcore.TokenCredential")
	}
	return cred, true, nil
}
145 changes: 145 additions & 0 deletions pkg/storage/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ import (
"regexp"
"strings"
"testing"
"time"

"github.com/Azure/azure-pipeline-go/pipeline"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/go-autorest/autorest"
autorestazure "github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/mocks"
"github.com/Azure/go-autorest/autorest/to"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -1416,3 +1419,145 @@ func Test_storageManagementStateNonAzureStackHub(t *testing.T) {
})
}
}

// fakeTokenCredential is a test double satisfying azcore.TokenCredential. It
// never fails and hands out a token whose text identifies the instance.
type fakeTokenCredential struct {
	// id is appended to the "fake-token-" prefix of every issued token.
	id string
}

// GetToken ignores its arguments and returns a token named after f.id that
// expires one hour from now.
func (f *fakeTokenCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
	var issued azcore.AccessToken
	issued.Token = "fake-token-" + f.id
	issued.ExpiresOn = time.Now().Add(time.Hour)
	return issued, nil
}

// resetGlobalAzureCredentials empties the process-wide credential cache so
// that one test cannot observe entries left behind by another.
func resetGlobalAzureCredentials() {
	dropEntry := func(k, _ any) bool {
		globalAzureCredentials.Delete(k)
		// Returning true keeps Range going until every key was visited.
		return true
	}
	globalAzureCredentials.Range(dropEntry)
}

// TestEnsureUAMICredentials exercises the paths of ensureUAMICredentials that
// do not need a real credentials file: the environment variable being unset,
// a cache hit, and a cache entry of the wrong type.
func TestEnsureUAMICredentials(t *testing.T) {
	const credentialsPathEnv = "MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH"

	for _, tt := range []struct {
		name         string
		envValue     string
		cacheSetup   func()
		expectedCred azcore.TokenCredential
		expectedOk   bool
		expectedErr  string
	}{
		{
			name:         "environment variable not set",
			envValue:     "",
			expectedCred: nil,
			expectedOk:   false,
			expectedErr:  "",
		},
		{
			name:     "credential loaded from cache",
			envValue: "/path/to/creds.json",
			cacheSetup: func() {
				fakeCred := &fakeTokenCredential{id: "cached"}
				globalAzureCredentials.Store(azureCredentialsKey, fakeCred)
			},
			expectedCred: &fakeTokenCredential{id: "cached"},
			expectedOk:   true,
			expectedErr:  "",
		},
		{
			name:     "invalid cached credential type",
			envValue: "/path/to/creds.json",
			cacheSetup: func() {
				// Store wrong type in cache
				globalAzureCredentials.Store(azureCredentialsKey, "not-a-credential")
			},
			expectedCred: nil,
			expectedOk:   false,
			expectedErr:  "expected cached credential to be azcore.TokenCredential",
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			// Always pin the variable, even to the empty string, so the
			// "not set" case does not depend on the ambient environment of
			// the machine running the tests. t.Setenv restores the previous
			// value when the subtest ends.
			t.Setenv(credentialsPathEnv, tt.envValue)

			// Start every case from an empty cache and leave an empty cache
			// behind, so no state leaks between subtests or into other tests.
			resetGlobalAzureCredentials()
			t.Cleanup(resetGlobalAzureCredentials)
			if tt.cacheSetup != nil {
				tt.cacheSetup()
			}

			// Create driver and call the function
			d := &driver{}
			env := autorestazure.PublicCloud
			cred, ok, err := d.ensureUAMICredentials(context.Background(), env)

			// Verify error
			if tt.expectedErr != "" {
				if err == nil || err.Error() != tt.expectedErr {
					t.Errorf("expected error %q, got %v", tt.expectedErr, err)
				}
			} else if err != nil {
				t.Errorf("unexpected error: %v", err)
			}

			// Verify ok result
			if ok != tt.expectedOk {
				t.Errorf("expected ok=%v, got %v", tt.expectedOk, ok)
			}

			// Verify credential result
			switch {
			case tt.expectedCred == nil:
				if cred != nil {
					t.Errorf("expected nil credential, got %v", cred)
				}
			case cred == nil:
				t.Errorf("expected credential, got nil")
			default:
				// The expected and cached credentials are distinct instances
				// built with the same id, so compare the tokens they issue.
				expectedToken, expectedTokenErr := tt.expectedCred.GetToken(context.Background(), policy.TokenRequestOptions{})
				if expectedTokenErr != nil {
					t.Fatalf("expected credential failed to issue a token: %v", expectedTokenErr)
				}
				actualToken, actualTokenErr := cred.GetToken(context.Background(), policy.TokenRequestOptions{})
				if actualTokenErr != nil {
					t.Fatalf("returned credential failed to issue a token: %v", actualTokenErr)
				}
				if expectedToken.Token != actualToken.Token {
					t.Errorf("expected credential with token %q, got %q", expectedToken.Token, actualToken.Token)
				}
			}
		})
	}
}

// TestEnsureUAMICredentials_CacheUsage verifies that repeated calls to
// ensureUAMICredentials return the very credential instance that was placed
// in the global cache, rather than a new or different one.
func TestEnsureUAMICredentials_CacheUsage(t *testing.T) {
	// Set up environment using t.Setenv
	t.Setenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH", "/path/to/creds.json")

	// Reset cache and add a credential; clear it again afterwards so the
	// fake credential does not leak into other tests.
	resetGlobalAzureCredentials()
	t.Cleanup(resetGlobalAzureCredentials)
	fakeCred := &fakeTokenCredential{id: "test"}
	globalAzureCredentials.Store(azureCredentialsKey, fakeCred)

	d := &driver{}
	env := autorestazure.PublicCloud

	// Both calls must be served from the cache.
	for _, call := range []string{"first", "second"} {
		cred, ok, err := d.ensureUAMICredentials(context.Background(), env)
		if err != nil || !ok || cred == nil {
			t.Fatalf("%s call failed: err=%v ok=%v cred=%v", call, err, ok, cred)
		}

		// Compare identity, not issued tokens: any fake built with the same
		// id would produce an equal token string.
		if cred != azcore.TokenCredential(fakeCred) {
			t.Errorf("%s call: expected the cached credential instance %p, got %v", call, fakeCred, cred)
		}
	}
}
2 changes: 1 addition & 1 deletion pkg/storage/azure/azureclient/azureclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (c *Client) getCreds(ctx context.Context) (azcore.TokenCredential, error) {
var ok bool

// We need to only store the Azure credentials once and reuse them after that.
storedCreds, found := c.azureCredentials.Load(userAssignedIdentityCredentialsFilePath)
storedCreds, found := c.azureCredentials.Load(azureCredentialsKey)
if !found {
klog.V(2).Info("Using UserAssignedIdentityCredentials for Azure authentication for managed Azure HCP")
clientOptions := azcore.ClientOptions{
Expand Down