diff --git a/pkg/controller/controller_test.go b/pkg/controller/controller_test.go index 4e6dc8d41..239fef686 100644 --- a/pkg/controller/controller_test.go +++ b/pkg/controller/controller_test.go @@ -211,7 +211,7 @@ func TestController(t *testing.T) { disableVolumeInUseErrorHandler: true, }, } { - client := csi.NewMockClient("mock", test.NodeResize, true, false, true, true, false) + client := csi.NewMockClient("mock", test.NodeResize, true, false, true, true) driverName, _ := client.GetDriverName(context.TODO()) var expectedCap resource.Quantity @@ -378,7 +378,7 @@ func TestResizePVC(t *testing.T) { }, } { t.Run(test.Name, func(t *testing.T) { - client := csi.NewMockClient("mock", test.NodeResize, true, false, true, true, false) + client := csi.NewMockClient("mock", test.NodeResize, true, false, true, true) if test.expansionError != nil { client.SetExpansionError(test.expansionError) } diff --git a/pkg/controller/expand_and_recover_test.go b/pkg/controller/expand_and_recover_test.go index 8d86d21e4..625359150 100644 --- a/pkg/controller/expand_and_recover_test.go +++ b/pkg/controller/expand_and_recover_test.go @@ -159,7 +159,7 @@ func TestExpandAndRecover(t *testing.T) { test := tests[i] t.Run(test.name, func(t *testing.T) { featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.RecoverVolumeExpansionFailure, true) - client := csi.NewMockClient("foo", !test.disableNodeExpansion, !test.disableControllerExpansion, false, true, true, false) + client := csi.NewMockClient("foo", !test.disableNodeExpansion, !test.disableControllerExpansion, false, true, true) driverName, _ := client.GetDriverName(context.TODO()) if test.expansionError != nil { client.SetExpansionError(test.expansionError) diff --git a/pkg/controller/resize_status_test.go b/pkg/controller/resize_status_test.go index 7e13bfea6..67afc6844 100644 --- a/pkg/controller/resize_status_test.go +++ b/pkg/controller/resize_status_test.go @@ -77,7 +77,7 @@ func TestResizeFunctions(t *testing.T) { tc := test t.Run(tc.name, func(t *testing.T) { featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.RecoverVolumeExpansionFailure, true) - client := csi.NewMockClient("foo", true, true, false, true, true, false) + client := csi.NewMockClient("foo", true, true, false, true, true) driverName, _ := client.GetDriverName(context.TODO()) pvc := test.pvc diff --git a/pkg/csi/mock_client.go b/pkg/csi/mock_client.go index 49c93c3e2..dc8a03b7e 100644 --- a/pkg/csi/mock_client.go +++ b/pkg/csi/mock_client.go @@ -3,6 +3,8 @@ package csi import ( "context" "fmt" + "maps" + "sync" "sync/atomic" "github.com/container-storage-interface/spec/lib/go/csi" @@ -16,7 +18,6 @@ func NewMockClient( supportsControllerModify bool, supportsPluginControllerService bool, supportsControllerSingleNodeMultiWriter bool, - supportsExtraModifyMetada bool, ) *MockClient { return &MockClient{ name: name, @@ -25,7 +26,7 @@ func NewMockClient( supportsControllerModify: supportsControllerModify, supportsPluginControllerService: supportsPluginControllerService, supportsControllerSingleNodeMultiWriter: supportsControllerSingleNodeMultiWriter, - extraModifyMetadata: supportsExtraModifyMetada, + modifiedParameters: make(map[string]string), } } @@ -43,7 +44,8 @@ type MockClient struct { checkMigratedLabel bool usedSecrets atomic.Pointer[map[string]string] usedCapability atomic.Pointer[csi.VolumeCapability] - extraModifyMetadata bool + modifyMu sync.Mutex + modifiedParameters map[string]string } func (c *MockClient) GetDriverName(context.Context) (string, error) { @@ -116,6 +118,12 @@ func (c *MockClient) GetModifyCount() int { return int(c.modifyCalled.Load()) } +func (c *MockClient) GetModifiedParameters() map[string]string { + c.modifyMu.Lock() + defer c.modifyMu.Unlock() + return maps.Clone(c.modifiedParameters) +} + func (c *MockClient) GetCapability() *csi.VolumeCapability { return c.usedCapability.Load() } @@ -138,5 +146,8 @@ func (c *MockClient) Modify( if c.modifyError != nil { return c.modifyError } + c.modifyMu.Lock() + defer c.modifyMu.Unlock() + maps.Copy(c.modifiedParameters, mutableParameters) return nil } diff --git a/pkg/modifier/csi_modifier_test.go b/pkg/modifier/csi_modifier_test.go index de68df9f2..179eb0341 100644 --- a/pkg/modifier/csi_modifier_test.go +++ b/pkg/modifier/csi_modifier_test.go @@ -28,7 +28,7 @@ func TestNewModifier(t *testing.T) { SupportsControllerModify: false, }, } { - client := csi.NewMockClient("mock", false, false, c.SupportsControllerModify, false, false, false) + client := csi.NewMockClient("mock", false, false, c.SupportsControllerModify, false, false) driverName := "mock-driver" k8sClient, informerFactory := fakeK8s() _, err := NewModifierFromClient(client, 0, k8sClient, informerFactory, false, driverName) diff --git a/pkg/modifycontroller/controller_test.go b/pkg/modifycontroller/controller_test.go index 3acd842e1..f78a8c951 100644 --- a/pkg/modifycontroller/controller_test.go +++ b/pkg/modifycontroller/controller_test.go @@ -65,7 +65,7 @@ func TestController(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Setup - client := csi.NewMockClient(testDriverName, true, true, true, true, true, false) + client := csi.NewMockClient(testDriverName, true, true, true, true, true) initialObjects := []runtime.Object{test.pvc, test.pv, testVacObject, targetVacObject} ctrlInstance := setupFakeK8sEnvironment(t, client, initialObjects) @@ -116,7 +116,7 @@ func TestModifyPVC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - client := csi.NewMockClient(testDriverName, true, true, true, true, true, false) + client := csi.NewMockClient(testDriverName, true, true, true, true, true) if test.modifyFailure { client.SetModifyError(fmt.Errorf("fake modification error")) } @@ -217,7 +217,7 @@ func TestSyncPVC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - client := csi.NewMockClient(testDriverName, true, true, true, true, true, false) + client := csi.NewMockClient(testDriverName, true, true, true, true, true) initialObjects := []runtime.Object{test.pvc, test.pv, testVacObject, targetVacObject} ctrlInstance := setupFakeK8sEnvironment(t, client, initialObjects) @@ -277,7 +277,7 @@ func TestInfeasibleRetry(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Setup - client := csi.NewMockClient(testDriverName, true, true, true, true, true, false) + client := csi.NewMockClient(testDriverName, true, true, true, true, true) if test.csiModifyError != nil { client.SetModifyError(test.csiModifyError) } diff --git a/pkg/modifycontroller/modify_status_test.go b/pkg/modifycontroller/modify_status_test.go index 50c7de27a..ed1e2143e 100644 --- a/pkg/modifycontroller/modify_status_test.go +++ b/pkg/modifycontroller/modify_status_test.go @@ -104,7 +104,7 @@ func TestMarkControllerModifyVolumeStatus(t *testing.T) { tc := test t.Run(tc.name, func(t *testing.T) { featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.VolumeAttributesClass, true) - client := csi.NewMockClient("foo", true, true, true, true, true, false) + client := csi.NewMockClient("foo", true, true, true, true, true) driverName, _ := client.GetDriverName(context.TODO()) pvc := test.pvc @@ -164,7 +164,7 @@ func TestUpdateConditionBasedOnError(t *testing.T) { tc := test t.Run(tc.name, func(t *testing.T) { featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.VolumeAttributesClass, true) - client := csi.NewMockClient("foo", true, true, true, true, true, false) + client := csi.NewMockClient("foo", true, true, true, true, true) driverName, _ := client.GetDriverName(context.TODO()) pvc := test.pvc @@ -233,7 +233,7 @@ func TestMarkControllerModifyVolumeCompleted(t *testing.T) { tc := test t.Run(tc.name, func(t *testing.T) { featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.VolumeAttributesClass, true) - client := csi.NewMockClient("foo", true, true, true, true, true, false) + client := csi.NewMockClient("foo", true, true, true, true, true) driverName, _ := client.GetDriverName(context.TODO()) var initialObjects []runtime.Object @@ -295,7 +295,7 @@ func TestRemovePVCFromModifyVolumeUncertainCache(t *testing.T) { tc := test t.Run(tc.name, func(t *testing.T) { featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.VolumeAttributesClass, true) - client := csi.NewMockClient("foo", true, true, true, true, true, false) + client := csi.NewMockClient("foo", true, true, true, true, true) driverName, _ := client.GetDriverName(context.TODO()) var initialObjects []runtime.Object diff --git a/pkg/modifycontroller/modify_volume.go b/pkg/modifycontroller/modify_volume.go index 7523ab624..8e6290cb7 100644 --- a/pkg/modifycontroller/modify_volume.go +++ b/pkg/modifycontroller/modify_volume.go @@ -18,6 +18,7 @@ package modifycontroller import ( "fmt" + "maps" "time" "github.com/kubernetes-csi/csi-lib-utils/slowset" @@ -161,12 +162,18 @@ func (ctrl *modifyController) callModifyVolumeOnPlugin( pvc *v1.PersistentVolumeClaim, pv *v1.PersistentVolume, vac *storagev1beta1.VolumeAttributesClass) (*v1.PersistentVolumeClaim, *v1.PersistentVolume, error) { + parameters := vac.Parameters if ctrl.extraModifyMetadata { - vac.Parameters[pvcNameKey] = pvc.GetName() - vac.Parameters[pvcNamespaceKey] = pvc.GetNamespace() - vac.Parameters[pvNameKey] = pv.GetName() + if len(parameters) == 0 { + parameters = make(map[string]string, 3) + } else { + parameters = maps.Clone(parameters) + } + parameters[pvcNameKey] = pvc.GetName() + parameters[pvcNamespaceKey] = pvc.GetNamespace() + parameters[pvNameKey] = pv.GetName() } - err := ctrl.modifier.Modify(pv, vac.Parameters) + err := ctrl.modifier.Modify(pv, parameters) if err != nil { return pvc, pv, err diff --git a/pkg/modifycontroller/modify_volume_test.go b/pkg/modifycontroller/modify_volume_test.go index baffbe3ec..fd1f991d8 100644 --- a/pkg/modifycontroller/modify_volume_test.go +++ b/pkg/modifycontroller/modify_volume_test.go @@ -25,12 +25,7 @@ var ( targetVacObject = &storagev1beta1.VolumeAttributesClass{ ObjectMeta: metav1.ObjectMeta{Name: targetVac}, DriverName: testDriverName, - Parameters: map[string]string{ - "iops": "4567", - "csi.storage.k8s.io/pvc/name": pvcName, - "csi.storage.k8s.io/pvc/namespace": pvcNamespace, - "csi.storage.k8s.io/pv/name": pvName, - }, + Parameters: map[string]string{"iops": "4567"}, } ) @@ -48,7 +43,7 @@ func TestModify(t *testing.T) { expectedCurrentVolumeAttributesClassName *string expectedPVVolumeAttributesClassName *string withExtraMetadata bool - expectedVacParams map[string]string + expectedMutableParams map[string]string }{ { name: "nothing to modify", @@ -80,6 +75,7 @@ func TestModify(t *testing.T) { expectedModifyVolumeStatus: nil, expectedCurrentVolumeAttributesClassName: &targetVac, expectedPVVolumeAttributesClassName: &targetVac, + expectedMutableParams: map[string]string{"iops": "4567"}, }, { name: "modify volume success with extra metadata", @@ -91,7 +87,7 @@ func TestModify(t *testing.T) { expectedCurrentVolumeAttributesClassName: &targetVac, expectedPVVolumeAttributesClassName: &targetVac, withExtraMetadata: true, - expectedVacParams: map[string]string{ + expectedMutableParams: map[string]string{ "iops": "4567", "csi.storage.k8s.io/pvc/name": basePVC.GetName(), "csi.storage.k8s.io/pvc/namespace": basePVC.GetNamespace(), @@ -104,12 +100,13 @@ func TestModify(t *testing.T) { test := tests[i] t.Run(test.name, func(t *testing.T) { // Setup - client := csi.NewMockClient(testDriverName, true, true, true, true, true, test.withExtraMetadata) + client := csi.NewMockClient(testDriverName, true, true, true, true, true) initialObjects := []runtime.Object{test.pvc, test.pv, testVacObject} if test.vacExists { initialObjects = append(initialObjects, targetVacObject) } ctrlInstance := setupFakeK8sEnvironment(t, client, initialObjects) + ctrlInstance.extraModifyMetadata = test.withExtraMetadata // Action pvc, pv, err, modifyCalled := ctrlInstance.modify(test.pvc, test.pv) @@ -138,15 +135,10 @@ func TestModify(t *testing.T) { t.Errorf("expected VolumeAttributesClassName of pv to be %v, got %v", *test.expectedPVVolumeAttributesClassName, *actualPVVolumeAttributesClassName) } - if test.withExtraMetadata { - vacObj, err := ctrlInstance.vacLister.Get(*test.expectedPVVolumeAttributesClassName) - if err != nil { - t.Errorf("failed to get VAC: %v", err) - } else { - vacParams := vacObj.Parameters - if diff := cmp.Diff(test.expectedVacParams, vacParams); diff != "" { - t.Errorf("expected VAC parameters to be %v, got %v", test.expectedVacParams, vacParams) - } + if test.expectedMutableParams != nil { + p := client.GetModifiedParameters() + if diff := cmp.Diff(test.expectedMutableParams, p); diff != "" { + t.Errorf("expected mutable parameters to be %v, got %v", test.expectedMutableParams, p) } } }) diff --git a/pkg/resizer/csi_resizer_test.go b/pkg/resizer/csi_resizer_test.go index 6d089511f..ea6b9e023 100644 --- a/pkg/resizer/csi_resizer_test.go +++ b/pkg/resizer/csi_resizer_test.go @@ -72,7 +72,7 @@ func TestNewResizer(t *testing.T) { Error: resizeNotSupportErr, }, } { - client := csi.NewMockClient("mock", c.SupportsNodeResize, c.SupportsControllerResize, false, c.SupportsPluginControllerService, c.SupportsControllerSingleNodeMultiWriter, false) + client := csi.NewMockClient("mock", c.SupportsNodeResize, c.SupportsControllerResize, false, c.SupportsPluginControllerService, c.SupportsControllerSingleNodeMultiWriter) driverName := "mock-driver" k8sClient := fake.NewSimpleClientset() resizer, err := NewResizerFromClient(client, 0, k8sClient, driverName) @@ -106,7 +106,7 @@ func TestResizeWithSecret(t *testing.T) { }, } for _, tc := range tests { - client := csi.NewMockClient("mock", true, true, false, true, true, false) + client := csi.NewMockClient("mock", true, true, false, true, true) secret := makeSecret("some-secret", "secret-namespace") k8sClient := fake.NewSimpleClientset(secret) pv := makeTestPV("test-csi", 2, "ebs-csi", "vol-abcde", tc.hasExpansionSecret) @@ -164,7 +164,7 @@ func TestResizeMigratedPV(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { driverName := tc.driverName - client := csi.NewMockClient(driverName, true, true, false, true, true, false) + client := csi.NewMockClient(driverName, true, true, false, true, true) client.SetCheckMigratedLabel() k8sClient := fake.NewSimpleClientset() resizer, err := NewResizerFromClient(client, 0, k8sClient, driverName) @@ -433,7 +433,7 @@ func TestCanSupport(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { driverName := tc.driverName - client := csi.NewMockClient(driverName, true, true, false, true, true, false) + client := csi.NewMockClient(driverName, true, true, false, true, true) k8sClient := fake.NewSimpleClientset() resizer, err := NewResizerFromClient(client, 0, k8sClient, driverName) if err != nil {