diff --git a/pkg/blob/blob.go b/pkg/blob/blob.go index 0fa01e9dc..08e257a3d 100644 --- a/pkg/blob/blob.go +++ b/pkg/blob/blob.go @@ -83,6 +83,7 @@ const ( storageSPNClientIDField = "azurestoragespnclientid" storageSPNTenantIDField = "azurestoragespntenantid" storageAuthTypeField = "azurestorageauthtype" + storageAuthTypeMSI = "msi" storageIdentityClientIDField = "azurestorageidentityclientid" storageIdentityObjectIDField = "azurestorageidentityobjectid" storageIdentityResourceIDField = "azurestorageidentityresourceid" @@ -635,7 +636,7 @@ func (d *Driver) GetAuthEnv(ctx context.Context, volumeID, protocol string, attr if spnTenantID != "" { storageSPNTenantID = spnTenantID } - if err != nil && strings.EqualFold(azureStorageAuthType, "msi") { + if err != nil && strings.EqualFold(azureStorageAuthType, storageAuthTypeMSI) { klog.V(2).Infof("ignore error(%v) since secret is optional for auth type(%s)", err, azureStorageAuthType) err = nil } @@ -708,6 +709,23 @@ func (d *Driver) GetAuthEnv(ctx context.Context, volumeID, protocol string, attr authEnv = append(authEnv, "AZURE_STORAGE_SPN_TENANT_ID="+storageSPNTenantID) } + if azureStorageAuthType == storageAuthTypeMSI { + // check whether authEnv contains AZURE_STORAGE_IDENTITY_ prefix + containsIdentityEnv := false + for _, env := range authEnv { + if strings.HasPrefix(env, "AZURE_STORAGE_IDENTITY_") { + klog.V(2).Infof("AZURE_STORAGE_IDENTITY_ is already set in authEnv, skip setting it again") + containsIdentityEnv = true + break + } + } + if !containsIdentityEnv && d.cloud != nil && d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID != "" { + klog.V(2).Infof("azureStorageAuthType is set to %s, add AZURE_STORAGE_IDENTITY_CLIENT_ID(%s) into authEnv", + azureStorageAuthType, d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID) + authEnv = append(authEnv, "AZURE_STORAGE_IDENTITY_CLIENT_ID="+d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID) + } + } + return rgName, accountName, accountKey, containerName, authEnv, err } diff --git a/pkg/blob/blob_test.go b/pkg/blob/blob_test.go index e1ef9ee47..64ed50be9 100644 --- a/pkg/blob/blob_test.go +++ b/pkg/blob/blob_test.go @@ -588,6 +588,54 @@ func TestGetAuthEnv(t *testing.T) { } }, }, + { + name: "valid request with MSIAuthTypeAddsIdentityEnv", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &storage.AccountRepo{} + d.cloud.Config.AzureAuthConfig = azclient.AzureAuthConfig{ + UserAssignedIdentityID: "unit-test-identity-id", + } + + attrib := map[string]string{ + subscriptionIDField: "subID", + resourceGroupField: "rg", + storageAccountField: "accountname", + storageAccountNameField: "accountname", + secretNameField: "secretName", + secretNamespaceField: "sNS", + containerNameField: "containername", + mountWithWITokenField: "false", + pvcNamespaceKey: "pvcNSKey", + getAccountKeyFromSecretField: "false", + storageAuthTypeField: storageAuthTypeMSI, + msiEndpointField: "msiEndpoint", + getLatestAccountKeyField: "true", + } + secret := make(map[string]string) + volumeID := "rg#f5713de20cde511e8ba4900#pvc-fuse-dynamic-17e43f84-f474-11e8-acd0-000d3a00df41" + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStorageAccountsClient := mock_accountclient.NewMockInterface(ctrl) + d.cloud.ComputeClientFactory = mock_azclient.NewMockClientFactory(ctrl) + d.cloud.ComputeClientFactory.(*mock_azclient.MockClientFactory).EXPECT().GetAccountClient().Return(mockStorageAccountsClient).AnyTimes() + s := "unit-test" + accountkey := armstorage.AccountKey{Value: &s} + list := []*armstorage.AccountKey{&accountkey} + mockStorageAccountsClient.EXPECT().ListKeys(gomock.Any(), gomock.Any(), gomock.Any()).Return(list, nil).AnyTimes() + d.cloud.ComputeClientFactory.(*mock_azclient.MockClientFactory).EXPECT().GetAccountClientForSub(gomock.Any()).Return(mockStorageAccountsClient, nil).AnyTimes() + _, _, _, _, authEnv, err := d.GetAuthEnv(context.TODO(), volumeID, "", attrib, secret) + assert.NoError(t, err) + found := false + for _, env := range authEnv { + if env == "AZURE_STORAGE_IDENTITY_CLIENT_ID=unit-test-identity-id" { + found = true + break + } + } + assert.True(t, found, "AZURE_STORAGE_IDENTITY_CLIENT_ID should be present in authEnv") + }, + }, { name: "invalid getLatestAccountKey value", testFunc: func(t *testing.T) {