Skip to content

Commit b958d28

Browse files
committed
feat(azure): add ensureUAMICredentials function with comprehensive tests
- Add ensureUAMICredentials function to obtain and cache Azure TokenCredential using User Assigned Managed Identity (UAMI) - Function loads credentials from global cache or creates new ones with proper environment configuration - Add comprehensive unit tests covering environment variable handling, cache behavior, and error scenarios - Tests use table-driven pattern following codebase conventions and t.Setenv for proper environment handling - Update Azure client instantiation to use UAMI credentials when available
1 parent 5969417 commit b958d28

File tree

2 files changed

+228
-27
lines changed

2 files changed

+228
-27
lines changed

pkg/storage/azure/azure.go

Lines changed: 83 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ const (
5757
azureCredentialsKey = "AzureCredentials"
5858
)
5959

60+
// globalAzureCredentials caches User Assigned Managed Identity (UAMI) credentials across driver instances so that
61+
// reconciles do not recreate the credential repeatedly.
62+
var globalAzureCredentials sync.Map
63+
6064
// storageAccountInvalidCharRe is a regular expression for characters that
6165
// cannot be used in Azure storage accounts names (i.e. that are not
6266
// numbers nor lower-case letters) and that are not upper-case letters. If
@@ -318,10 +322,6 @@ type driver struct {
318322
// policies is for new Azure Client Pipeline execution.
319323
// Added as a member to the struct to allow injection for testing.
320324
policies []policy.Policy
321-
322-
// azureCredentials keeps track if we have already loaded an Azure
323-
// credentials token when using UAMI for managed Azure on HCP.
324-
azureCredentials sync.Map
325325
}
326326

327327
// NewDriver creates a new storage driver for Azure Blob Storage.
@@ -334,7 +334,7 @@ func NewDriver(ctx context.Context, c *imageregistryv1.ImageRegistryConfigStorag
334334
}
335335

336336
func (d *driver) newAzClient(cfg *Azure, environment autorestazure.Environment, tagset map[string]*string) (*azureclient.Client, error) {
337-
client, err := azureclient.New(&azureclient.Options{
337+
clientOptions := &azureclient.Options{
338338
Environment: environment,
339339
TenantID: cfg.TenantID,
340340
ClientID: cfg.ClientID,
@@ -343,10 +343,19 @@ func (d *driver) newAzClient(cfg *Azure, environment autorestazure.Environment,
343343
SubscriptionID: cfg.SubscriptionID,
344344
TagSet: tagset,
345345
Policies: d.policies,
346-
})
346+
}
347+
348+
if cred, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil {
349+
return nil, err
350+
} else if ok {
351+
clientOptions.Creds = cred
352+
}
353+
354+
client, err := azureclient.New(clientOptions)
347355
if err != nil {
348356
return nil, err
349357
}
358+
350359
return client, nil
351360
}
352361

@@ -381,25 +390,10 @@ func (d *driver) storageAccountsClient(cfg *Azure, environment autorestazure.Env
381390
// UserAssignedIdentityCredentials is specifically for managed Azure HCP
382391
userAssignedIdentityCredentialsFilePath := os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH")
383392
if userAssignedIdentityCredentialsFilePath != "" {
384-
var ok bool
385-
386-
// We need to only store the Azure credentials once and reuse them after that.
387-
storedCreds, found := d.azureCredentials.Load(azureCredentialsKey)
388-
if !found {
389-
klog.V(2).Info("Using UserAssignedIdentityCredentials for Azure authentication for managed Azure HCP")
390-
clientOptions := azcore.ClientOptions{
391-
Cloud: cloudConfig,
392-
}
393-
cred, err = dataplane.NewUserAssignedIdentityCredential(context.Background(), userAssignedIdentityCredentialsFilePath, dataplane.WithClientOpts(clientOptions))
394-
if err != nil {
395-
return storage.AccountsClient{}, err
396-
}
397-
d.azureCredentials.Store(azureCredentialsKey, cred)
398-
} else {
399-
cred, ok = storedCreds.(azcore.TokenCredential)
400-
if !ok {
401-
return storage.AccountsClient{}, fmt.Errorf("expected %T to be a TokenCredential", storedCreds)
402-
}
393+
if c, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil {
394+
return storage.AccountsClient{}, err
395+
} else if ok {
396+
cred = c
403397
}
404398
} else if strings.TrimSpace(cfg.ClientSecret) == "" {
405399
options := azidentity.WorkloadIdentityCredentialOptions{
@@ -1237,14 +1231,30 @@ func (d *driver) RemoveStorage(cr *imageregistryv1.Config) (retry bool, err erro
12371231
}
12381232

12391233
if d.Config.NetworkAccess != nil && d.Config.NetworkAccess.Internal != nil && d.Config.NetworkAccess.Internal.PrivateEndpointName != "" {
1240-
azclient, err := azureclient.New(&azureclient.Options{
1234+
clientOptions := &azureclient.Options{
12411235
Environment: environment,
12421236
TenantID: cfg.TenantID,
12431237
ClientID: cfg.ClientID,
12441238
ClientSecret: cfg.ClientSecret,
12451239
FederatedTokenFile: cfg.FederatedTokenFile,
12461240
SubscriptionID: cfg.SubscriptionID,
1247-
})
1241+
}
1242+
1243+
if cred, ok, err := d.ensureUAMICredentials(d.Context, environment); err != nil {
1244+
util.UpdateCondition(
1245+
cr,
1246+
defaults.StorageExists,
1247+
operatorapiv1.ConditionUnknown,
1248+
storageExistsReasonAzureError,
1249+
fmt.Sprintf("Unable to get azure client: %s", err),
1250+
)
1251+
return false, err
1252+
} else if ok {
1253+
klog.V(2).Infof("Using cached UAMI credential for RemoveStorage client")
1254+
clientOptions.Creds = cred
1255+
}
1256+
1257+
azclient, err := azureclient.New(clientOptions)
12481258
if err != nil {
12491259
util.UpdateCondition(
12501260
cr,
@@ -1320,3 +1330,49 @@ func (d *driver) RemoveStorage(cr *imageregistryv1.Config) (retry bool, err erro
13201330
func (d *driver) ID() string {
13211331
return d.Config.Container
13221332
}
1333+
1334+
// ensureUAMICredentials obtains and caches an Azure TokenCredential using a
1335+
// User Assigned Managed Identity (UAMI).
1336+
//
1337+
// If MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH is unset, it returns (nil, false, nil).
1338+
// When set, it loads a credential from a process-wide cache or creates one using the
1339+
// provided Azure environment endpoints, stores it, and returns it.
1340+
//
1341+
// The bool result is true when a UAMI credential is available. An error is returned if
1342+
// credential creation fails or a cached value has an unexpected type.
1343+
//
1344+
// ctx controls cancellation of credential creation. env supplies Azure endpoints.
1345+
func (d *driver) ensureUAMICredentials(ctx context.Context, env autorestazure.Environment) (azcore.TokenCredential, bool, error) {
1346+
if stored, ok := globalAzureCredentials.Load(azureCredentialsKey); ok {
1347+
if cred, ok := stored.(azcore.TokenCredential); ok {
1348+
klog.V(2).Infof("Loaded UAMI credentials from cache")
1349+
return cred, true, nil
1350+
}
1351+
return nil, false, fmt.Errorf("expected cached credential to be azcore.TokenCredential")
1352+
}
1353+
if os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH") == "" {
1354+
return nil, false, nil
1355+
}
1356+
cloudConfig := cloud.Configuration{
1357+
ActiveDirectoryAuthorityHost: env.ActiveDirectoryEndpoint,
1358+
Services: map[cloud.ServiceName]cloud.ServiceConfiguration{
1359+
cloud.ResourceManager: {
1360+
Audience: env.TokenAudience,
1361+
Endpoint: env.ResourceManagerEndpoint,
1362+
},
1363+
},
1364+
}
1365+
cred, err := dataplane.NewUserAssignedIdentityCredential(
1366+
ctx,
1367+
os.Getenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH"),
1368+
dataplane.WithClientOpts(azcore.ClientOptions{Cloud: cloudConfig}),
1369+
)
1370+
if err != nil {
1371+
return nil, false, err
1372+
}
1373+
if actual, loaded := globalAzureCredentials.LoadOrStore(azureCredentialsKey, cred); loaded {
1374+
return actual.(azcore.TokenCredential), true, nil
1375+
}
1376+
klog.V(2).Infof("Storing UAMI credentials to global cache")
1377+
return cred, true, nil
1378+
}

pkg/storage/azure/azure_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ import (
1212
"regexp"
1313
"strings"
1414
"testing"
15+
"time"
1516

1617
"github.com/Azure/azure-pipeline-go/pipeline"
18+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
1719
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
1820
"github.com/Azure/go-autorest/autorest"
21+
autorestazure "github.com/Azure/go-autorest/autorest/azure"
1922
"github.com/Azure/go-autorest/autorest/mocks"
2023
"github.com/Azure/go-autorest/autorest/to"
2124
"github.com/google/go-cmp/cmp"
@@ -1416,3 +1419,145 @@ func Test_storageManagementStateNonAzureStackHub(t *testing.T) {
14161419
})
14171420
}
14181421
}
1422+
1423+
// fakeTokenCredential implements azcore.TokenCredential for testing
1424+
type fakeTokenCredential struct {
1425+
id string
1426+
}
1427+
1428+
func (f *fakeTokenCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
1429+
return azcore.AccessToken{
1430+
Token: "fake-token-" + f.id,
1431+
ExpiresOn: time.Now().Add(time.Hour),
1432+
}, nil
1433+
}
1434+
1435+
// resetGlobalAzureCredentials clears the global cache between tests
1436+
func resetGlobalAzureCredentials() {
1437+
globalAzureCredentials.Range(func(key, value any) bool {
1438+
globalAzureCredentials.Delete(key)
1439+
return true
1440+
})
1441+
}
1442+
1443+
func TestEnsureUAMICredentials(t *testing.T) {
1444+
for _, tt := range []struct {
1445+
name string
1446+
envValue string
1447+
cacheSetup func()
1448+
expectedCred azcore.TokenCredential
1449+
expectedOk bool
1450+
expectedErr string
1451+
}{
1452+
{
1453+
name: "environment variable not set",
1454+
envValue: "",
1455+
expectedCred: nil,
1456+
expectedOk: false,
1457+
expectedErr: "",
1458+
},
1459+
{
1460+
name: "credential loaded from cache",
1461+
envValue: "/path/to/creds.json",
1462+
cacheSetup: func() {
1463+
resetGlobalAzureCredentials()
1464+
fakeCred := &fakeTokenCredential{id: "cached"}
1465+
globalAzureCredentials.Store(azureCredentialsKey, fakeCred)
1466+
},
1467+
expectedCred: &fakeTokenCredential{id: "cached"},
1468+
expectedOk: true,
1469+
expectedErr: "",
1470+
},
1471+
{
1472+
name: "invalid cached credential type",
1473+
envValue: "/path/to/creds.json",
1474+
cacheSetup: func() {
1475+
resetGlobalAzureCredentials()
1476+
// Store wrong type in cache
1477+
globalAzureCredentials.Store(azureCredentialsKey, "not-a-credential")
1478+
},
1479+
expectedCred: nil,
1480+
expectedOk: false,
1481+
expectedErr: "expected cached credential to be azcore.TokenCredential",
1482+
},
1483+
} {
1484+
t.Run(tt.name, func(t *testing.T) {
1485+
// Set up environment using t.Setenv (Go 1.17+)
1486+
if tt.envValue != "" {
1487+
t.Setenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH", tt.envValue)
1488+
}
1489+
1490+
// Set up cache
1491+
if tt.cacheSetup != nil {
1492+
tt.cacheSetup()
1493+
}
1494+
1495+
// Create driver and call the function
1496+
d := &driver{}
1497+
env := autorestazure.PublicCloud
1498+
cred, ok, err := d.ensureUAMICredentials(context.Background(), env)
1499+
1500+
// Verify error
1501+
if tt.expectedErr != "" {
1502+
if err == nil || err.Error() != tt.expectedErr {
1503+
t.Errorf("expected error %q, got %v", tt.expectedErr, err)
1504+
}
1505+
} else if err != nil {
1506+
t.Errorf("unexpected error: %v", err)
1507+
}
1508+
1509+
// Verify ok result
1510+
if ok != tt.expectedOk {
1511+
t.Errorf("expected ok=%v, got %v", tt.expectedOk, ok)
1512+
}
1513+
1514+
// Verify credential result
1515+
if tt.expectedCred != nil {
1516+
if cred == nil {
1517+
t.Errorf("expected credential, got nil")
1518+
} else {
1519+
// Check that we got the right credential by comparing the token
1520+
expectedToken, _ := tt.expectedCred.GetToken(context.Background(), policy.TokenRequestOptions{})
1521+
actualToken, _ := cred.GetToken(context.Background(), policy.TokenRequestOptions{})
1522+
if expectedToken.Token != actualToken.Token {
1523+
t.Errorf("expected credential with token %q, got %q", expectedToken.Token, actualToken.Token)
1524+
}
1525+
}
1526+
} else if cred != nil {
1527+
t.Errorf("expected nil credential, got %v", cred)
1528+
}
1529+
})
1530+
}
1531+
}
1532+
1533+
func TestEnsureUAMICredentials_CacheUsage(t *testing.T) {
1534+
// Set up environment using t.Setenv
1535+
t.Setenv("MANAGED_AZURE_HCP_CREDENTIALS_FILE_PATH", "/path/to/creds.json")
1536+
1537+
// Reset cache and add a credential
1538+
resetGlobalAzureCredentials()
1539+
fakeCred := &fakeTokenCredential{id: "test"}
1540+
globalAzureCredentials.Store(azureCredentialsKey, fakeCred)
1541+
1542+
d := &driver{}
1543+
env := autorestazure.PublicCloud
1544+
1545+
// First call should load from cache
1546+
cred1, ok1, err1 := d.ensureUAMICredentials(context.Background(), env)
1547+
if err1 != nil || !ok1 || cred1 == nil {
1548+
t.Fatalf("first call failed: err=%v ok=%v cred=%v", err1, ok1, cred1)
1549+
}
1550+
1551+
// Second call should also load from cache
1552+
cred2, ok2, err2 := d.ensureUAMICredentials(context.Background(), env)
1553+
if err2 != nil || !ok2 || cred2 == nil {
1554+
t.Fatalf("second call failed: err=%v ok=%v cred=%v", err2, ok2, cred2)
1555+
}
1556+
1557+
// Verify same credential instance returned
1558+
token1, _ := cred1.GetToken(context.Background(), policy.TokenRequestOptions{})
1559+
token2, _ := cred2.GetToken(context.Background(), policy.TokenRequestOptions{})
1560+
if token1.Token != token2.Token {
1561+
t.Errorf("expected same credential from cache, got different tokens: %q vs %q", token1.Token, token2.Token)
1562+
}
1563+
}

0 commit comments

Comments
 (0)