diff --git a/pkg/csi/service/common/constants.go b/pkg/csi/service/common/constants.go
index 33e513a857..d6c97c0f29 100644
--- a/pkg/csi/service/common/constants.go
+++ b/pkg/csi/service/common/constants.go
@@ -353,6 +353,16 @@ const (
     // Guest cluster.
     SupervisorVolumeSnapshotAnnotationKey = "csi.vsphere.guest-initiated-csi-snapshot"
 
+    // MaxSnapshotsPerVolumeAnnotationKey represents the annotation key on Namespace CR
+    // in Supervisor cluster to specify the maximum number of snapshots per volume
+    MaxSnapshotsPerVolumeAnnotationKey = "csi.vsphere.max-snapshots-per-volume"
+
+    // DefaultMaxSnapshotsPerBlockVolumeInWCP is the default maximum number of snapshots per block volume in WCP
+    DefaultMaxSnapshotsPerBlockVolumeInWCP = 4
+
+    // MaxAllowedSnapshotsPerBlockVolume is the hard cap for maximum snapshots per block volume
+    MaxAllowedSnapshotsPerBlockVolume = 32
+
     // AttributeSupervisorVolumeSnapshotClass represents name of VolumeSnapshotClass
     AttributeSupervisorVolumeSnapshotClass = "svvolumesnapshotclass"
 
@@ -467,6 +477,8 @@ const (
     // WCPVMServiceVMSnapshots is a supervisor capability indicating
     // if supports_VM_service_VM_snapshots FSS is enabled
     WCPVMServiceVMSnapshots = "supports_VM_service_VM_snapshots"
+    // SnapshotLimitWCP is an internal FSS that enables snapshot limit enforcement in WCP
+    SnapshotLimitWCP = "snapshot-limit-wcp"
 )
 
 var WCPFeatureStates = map[string]struct{}{
diff --git a/pkg/csi/service/wcp/controller.go b/pkg/csi/service/wcp/controller.go
index 05ba648230..9dbf939f14 100644
--- a/pkg/csi/service/wcp/controller.go
+++ b/pkg/csi/service/wcp/controller.go
@@ -98,10 +98,23 @@ var (
     vmMoidToHostMoid, volumeIDToVMMap map[string]string
 )
 
+// volumeLock represents a lock for a specific volume with reference counting
+type volumeLock struct {
+    mutex    sync.Mutex
+    refCount int
+}
+
+// snapshotLockManager manages per-volume locks for snapshot operations
+type snapshotLockManager struct {
+    locks    map[string]*volumeLock
+    mapMutex sync.RWMutex
+}
+
 type controller struct {
-    manager     *common.Manager
-    authMgr     common.AuthorizationService
-    topologyMgr commoncotypes.ControllerTopologyService
+    manager         *common.Manager
+    authMgr         common.AuthorizationService
+    topologyMgr     commoncotypes.ControllerTopologyService
+    snapshotLockMgr *snapshotLockManager
     csi.UnimplementedControllerServer
 }
 
@@ -211,6 +224,12 @@ func (c *controller) Init(config *cnsconfig.Config, version string) error {
         CryptoClient: cryptoClient,
     }
 
+    // Initialize snapshot lock manager
+    c.snapshotLockMgr = &snapshotLockManager{
+        locks: make(map[string]*volumeLock),
+    }
+    log.Info("Initialized snapshot lock manager for per-volume serialization")
+
     vc, err := common.GetVCenter(ctx, c.manager)
     if err != nil {
         log.Errorf("failed to get vcenter. err=%v", err)
@@ -447,6 +466,53 @@ func (c *controller) ReloadConfiguration(reconnectToVCFromNewConfig bool) error
     return nil
 }
 
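The volumeLock/snapshotLockManager pair introduced above is a reference-counted keyed mutex: one mutex per volume ID, created on demand and deleted once the last holder or waiter is gone. The following is a minimal standalone sketch of that pattern (illustrative only; keyedMutex and entry are made-up names, not driver code):

```go
package main

import (
	"fmt"
	"sync"
)

// entry is one per-key lock plus a count of current holders and waiters.
type entry struct {
	mu  sync.Mutex
	ref int
}

// keyedMutex hands out one mutex per key and frees it when unused.
type keyedMutex struct {
	mu    sync.Mutex
	locks map[string]*entry
}

func newKeyedMutex() *keyedMutex {
	return &keyedMutex{locks: make(map[string]*entry)}
}

func (k *keyedMutex) Lock(key string) {
	k.mu.Lock()
	e, ok := k.locks[key]
	if !ok {
		e = &entry{}
		k.locks[key] = e
	}
	e.ref++ // count waiters so Unlock knows when the entry can be deleted
	k.mu.Unlock()

	// Block on the per-key mutex without holding the map lock; otherwise
	// Unlock, which also needs the map lock, could never run.
	e.mu.Lock()
}

func (k *keyedMutex) Unlock(key string) {
	k.mu.Lock()
	defer k.mu.Unlock()

	e, ok := k.locks[key]
	if !ok {
		return
	}
	e.mu.Unlock()
	e.ref--
	if e.ref == 0 {
		delete(k.locks, key) // no holders and no waiters left
	}
}

func main() {
	km := newKeyedMutex()
	var wg sync.WaitGroup
	counter := 0
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			km.Lock("volume-1")
			defer km.Unlock("volume-1")
			counter++ // serialized per key
		}()
	}
	wg.Wait()
	fmt.Println("counter:", counter) // always 5
}
```

The reference count is what makes map cleanup safe: an entry may only be deleted when no goroutine holds it or is waiting on it, which is exactly what refCount tracks in volumeLock.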
+// acquireSnapshotLock acquires a lock for the given volume ID.
+// It creates a new lock if one doesn't exist and increments the reference count.
+// The caller must call releaseSnapshotLock when done.
+func (c *controller) acquireSnapshotLock(ctx context.Context, volumeID string) {
+    log := logger.GetLogger(ctx)
+    c.snapshotLockMgr.mapMutex.Lock()
+    defer c.snapshotLockMgr.mapMutex.Unlock()
+
+    vLock, exists := c.snapshotLockMgr.locks[volumeID]
+    if !exists {
+        vLock = &volumeLock{}
+        c.snapshotLockMgr.locks[volumeID] = vLock
+        log.Debugf("Created new lock for volume %q", volumeID)
+    }
+    vLock.refCount++
+    log.Debugf("Acquired lock for volume %q, refCount: %d", volumeID, vLock.refCount)
+
+    // Unlock the map before blocking on the volume lock to avoid deadlock,
+    // then re-lock it so the deferred unlock stays balanced.
+    c.snapshotLockMgr.mapMutex.Unlock()
+    vLock.mutex.Lock()
+    c.snapshotLockMgr.mapMutex.Lock()
+}
+
+// releaseSnapshotLock releases the lock for the given volume ID.
+// It decrements the reference count and removes the lock if the count reaches zero.
+func (c *controller) releaseSnapshotLock(ctx context.Context, volumeID string) {
+    log := logger.GetLogger(ctx)
+    c.snapshotLockMgr.mapMutex.Lock()
+    defer c.snapshotLockMgr.mapMutex.Unlock()
+
+    vLock, exists := c.snapshotLockMgr.locks[volumeID]
+    if !exists {
+        log.Warnf("Attempted to release non-existent lock for volume %q", volumeID)
+        return
+    }
+
+    vLock.mutex.Unlock()
+    vLock.refCount--
+    log.Debugf("Released lock for volume %q, refCount: %d", volumeID, vLock.refCount)
+
+    // Clean up the lock if the reference count reaches zero
+    if vLock.refCount == 0 {
+        delete(c.snapshotLockMgr.locks, volumeID)
+        log.Debugf("Cleaned up lock for volume %q", volumeID)
+    }
+}
+
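acquireSnapshotLock/releaseSnapshotLock exist to close the check-then-create window in CreateSnapshot: the snapshot count is only meaningful if no other request for the same volume can create a snapshot between the count and the create. A self-contained sketch of that sequence (hypothetical helper names; a plain mutex stands in for the per-volume lock):

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// fakeLocker stands in for the controller's acquire/releaseSnapshotLock pair.
type fakeLocker struct{ mu sync.Mutex }

func (f *fakeLocker) acquire(_ context.Context, _ string) { f.mu.Lock() }
func (f *fakeLocker) release(_ context.Context, _ string) { f.mu.Unlock() }

// createSnapshotWithLimit shows the check-then-create sequence the lock protects:
// without it, two concurrent callers could both observe a count below the limit
// and both create a snapshot.
func createSnapshotWithLimit(ctx context.Context, l *fakeLocker, volumeID string,
	count func(string) int, create func(string) error, limit int) error {
	l.acquire(ctx, volumeID)
	defer l.release(ctx, volumeID)

	if n := count(volumeID); n >= limit {
		return fmt.Errorf("volume %s already has %d snapshots (limit %d)", volumeID, n, limit)
	}
	return create(volumeID)
}

func main() {
	var (
		locker    fakeLocker
		snapshots []string // pretend snapshot store for a single volume
	)
	count := func(string) int { return len(snapshots) }
	create := func(id string) error {
		snapshots = append(snapshots, id)
		return nil
	}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Errors past the limit are expected; ignore them in this sketch.
			_ = createSnapshotWithLimit(context.Background(), &locker, "vol-1", count, create, 4)
		}()
	}
	wg.Wait()
	fmt.Println("snapshots created:", len(snapshots)) // never exceeds 4
}
```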
 
 // createBlockVolume creates a block volume based on the CreateVolumeRequest.
 func (c *controller) createBlockVolume(ctx context.Context, req *csi.CreateVolumeRequest,
     isWorkloadDomainIsolationEnabled bool, clusterMoIds []string) (
@@ -2446,8 +2512,47 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
             "Queried VolumeType: %v", volumeType, cnsVolumeDetailsMap[volumeID].VolumeType)
     }
 
-    // TODO: We may need to add logic to check the limit of max number of snapshots by using
-    // GlobalMaxSnapshotsPerBlockVolume etc. variables in the future.
+    // Check the snapshot limit if the SnapshotLimitWCP feature is enabled.
+    isSnapshotLimitWCPEnabled := commonco.ContainerOrchestratorUtility.IsFSSEnabled(ctx, common.SnapshotLimitWCP)
+    if isSnapshotLimitWCPEnabled {
+        // Acquire the per-volume lock to serialize snapshot operations on this volume.
+        c.acquireSnapshotLock(ctx, volumeID)
+        defer c.releaseSnapshotLock(ctx, volumeID)
+
+        // Extract namespace from request parameters
+        volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
+        if volumeSnapshotNamespace == "" {
+            return nil, logger.LogNewErrorCodef(log, codes.Internal,
+                "volumesnapshot namespace is not set in the request parameters")
+        }
+
+        // Get snapshot limit from namespace annotation
+        snapshotLimit, err := getSnapshotLimitFromNamespace(ctx, volumeSnapshotNamespace)
+        if err != nil {
+            return nil, logger.LogNewErrorCodef(log, codes.Internal,
+                "failed to get snapshot limit for namespace %q: %v", volumeSnapshotNamespace, err)
+        }
+        log.Infof("Snapshot limit for namespace %q is set to %d", volumeSnapshotNamespace, snapshotLimit)
+
+        // Query existing snapshots for this volume
+        snapshotList, _, err := common.QueryVolumeSnapshotsByVolumeID(ctx, c.manager.VolumeManager, volumeID,
+            common.QuerySnapshotLimit)
+        if err != nil {
+            return nil, logger.LogNewErrorCodef(log, codes.Internal,
+                "failed to query snapshots for volume %q: %v", volumeID, err)
+        }
+
+        // Check if the limit is exceeded
+        currentSnapshotCount := len(snapshotList)
+        if currentSnapshotCount >= snapshotLimit {
+            return nil, logger.LogNewErrorCodef(log, codes.FailedPrecondition,
+                "the number of snapshots (%d) on the source volume %s has reached or exceeded "+
+                    "the configured maximum (%d) for namespace %s",
+                currentSnapshotCount, volumeID, snapshotLimit, volumeSnapshotNamespace)
+        }
+        log.Infof("Current snapshot count for volume %q is %d, within limit of %d",
+            volumeID, currentSnapshotCount, snapshotLimit)
+    }
 
     // the returned snapshotID below is a combination of CNS VolumeID and CNS SnapshotID concatenated by the "+"
     // sign. That is, a string of "<UUID>+<UUID>". Because, all other CNS snapshot APIs still require both
diff --git a/pkg/csi/service/wcp/controller_helper.go b/pkg/csi/service/wcp/controller_helper.go
index 541e4493e6..2d67ea69a7 100644
--- a/pkg/csi/service/wcp/controller_helper.go
+++ b/pkg/csi/service/wcp/controller_helper.go
@@ -37,6 +37,8 @@ import (
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/fields"
     "k8s.io/apimachinery/pkg/types"
+    "k8s.io/client-go/kubernetes"
+    restclient "k8s.io/client-go/rest"
     api "k8s.io/kubernetes/pkg/apis/core"
     "sigs.k8s.io/controller-runtime/pkg/client/config"
     spv1alpha1 "sigs.k8s.io/vsphere-csi-driver/v3/pkg/apis/storagepool/cns/v1alpha1"
@@ -849,6 +851,15 @@ func validateControllerPublishVolumeRequesInWcp(ctx context.Context, req *csi.Co
 
 var newK8sClient = k8s.NewClient
 
+// getK8sConfig is a variable that can be overridden for testing
+var getK8sConfig = config.GetConfig
+
+// newK8sClientFromConfig is a variable that can be overridden for testing.
+// It wraps kubernetes.NewForConfig and returns kubernetes.Interface for easier testing.
+var newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+    return kubernetes.NewForConfig(c)
+}
+
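getK8sConfig and newK8sClientFromConfig are package-level function variables so the unit tests further down can swap in a fake clientset without reaching a real API server. A stripped-down sketch of this seam pattern, with hypothetical names:

```go
// limits.go - production code reaches the API server only through the seam.
package limits

// fetchLimit is a package-level seam; tests reassign it.
var fetchLimit = func(namespace string) (int, error) {
	// The real implementation would build a Kubernetes client and read the
	// namespace annotation here, as getSnapshotLimitFromNamespace does.
	return 4, nil
}

func EffectiveLimit(namespace string) (int, error) {
	return fetchLimit(namespace)
}
```

```go
// limits_test.go - the test swaps the seam and restores it afterwards.
package limits

import "testing"

func TestEffectiveLimit(t *testing.T) {
	orig := fetchLimit
	defer func() { fetchLimit = orig }() // always restore the original seam

	fetchLimit = func(string) (int, error) { return 7, nil }

	got, err := EffectiveLimit("demo-ns")
	if err != nil || got != 7 {
		t.Fatalf("got %d, %v; want 7, nil", got, err)
	}
}
```

The limit these helpers read ultimately comes from the namespace annotation declared in constants.go; for example, an administrator could set it with a command along the lines of `kubectl annotate namespace demo-ns csi.vsphere.max-snapshots-per-volume=8`.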
 // getPodVMUUID returns the UUID of the VM(running on the node) on which the pod that is trying to
 // use the volume is scheduled.
 func getPodVMUUID(ctx context.Context, volumeID, nodeName string) (string, error) {
@@ -950,3 +961,63 @@ func GetZonesFromAccessibilityRequirements(ctx context.Context,
     }
     return zones, nil
 }
+
+// getSnapshotLimitFromNamespace retrieves the snapshot limit from the namespace annotation.
+// If the annotation is not present, it returns the default value.
+// If the annotation value exceeds the maximum allowed, it caps the value and logs a warning.
+func getSnapshotLimitFromNamespace(ctx context.Context, namespace string) (int, error) {
+    log := logger.GetLogger(ctx)
+
+    // Get Kubernetes config
+    cfg, err := getK8sConfig()
+    if err != nil {
+        return 0, logger.LogNewErrorCodef(log, codes.Internal,
+            "failed to get Kubernetes config: %v", err)
+    }
+
+    // Create Kubernetes clientset
+    k8sClient, err := newK8sClientFromConfig(cfg)
+    if err != nil {
+        return 0, logger.LogNewErrorCodef(log, codes.Internal,
+            "failed to create Kubernetes client: %v", err)
+    }
+
+    // Get namespace object
+    ns, err := k8sClient.CoreV1().Namespaces().Get(ctx, namespace, metav1.GetOptions{})
+    if err != nil {
+        return 0, logger.LogNewErrorCodef(log, codes.Internal,
+            "failed to get namespace %q: %v", namespace, err)
+    }
+
+    // Check if annotation exists
+    annotationValue, exists := ns.Annotations[common.MaxSnapshotsPerVolumeAnnotationKey]
+    if !exists {
+        log.Infof("Annotation %q not found in namespace %q, using default value %d",
+            common.MaxSnapshotsPerVolumeAnnotationKey, namespace, common.DefaultMaxSnapshotsPerBlockVolumeInWCP)
+        return common.DefaultMaxSnapshotsPerBlockVolumeInWCP, nil
+    }
+
+    // Parse annotation value
+    limit, err := strconv.Atoi(annotationValue)
+    if err != nil {
+        return 0, logger.LogNewErrorCodef(log, codes.Internal,
+            "failed to parse annotation %q value %q in namespace %q: %v",
+            common.MaxSnapshotsPerVolumeAnnotationKey, annotationValue, namespace, err)
+    }
+
+    // Validate limit
+    if limit < 0 {
+        return 0, logger.LogNewErrorCodef(log, codes.InvalidArgument,
+            "invalid snapshot limit %d in namespace %q: must be >= 0", limit, namespace)
+    }
+
+    // Cap to maximum allowed value
+    if limit > common.MaxAllowedSnapshotsPerBlockVolume {
+        log.Warnf("Snapshot limit %d in namespace %q exceeds maximum allowed %d, capping to %d",
+            limit, namespace, common.MaxAllowedSnapshotsPerBlockVolume, common.MaxAllowedSnapshotsPerBlockVolume)
+        return common.MaxAllowedSnapshotsPerBlockVolume, nil
+    }
+
+    log.Infof("Snapshot limit for namespace %q is set to %d", namespace, limit)
+    return limit, nil
+}
diff --git a/pkg/csi/service/wcp/controller_helper_test.go b/pkg/csi/service/wcp/controller_helper_test.go
index 4eaa2dfcfd..a1bb943f2e 100644
--- a/pkg/csi/service/wcp/controller_helper_test.go
+++ b/pkg/csi/service/wcp/controller_helper_test.go
@@ -5,13 +5,17 @@ import (
     "testing"
 
     "github.com/stretchr/testify/assert"
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"
     v1 "k8s.io/api/core/v1"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/runtime"
     "k8s.io/client-go/kubernetes"
     "k8s.io/client-go/kubernetes/fake"
+    restclient "k8s.io/client-go/rest"
     k8stesting "k8s.io/client-go/testing"
     "sigs.k8s.io/vsphere-csi-driver/v3/pkg/common/unittestcommon"
+    "sigs.k8s.io/vsphere-csi-driver/v3/pkg/csi/service/common"
     "sigs.k8s.io/vsphere-csi-driver/v3/pkg/csi/service/common/commonco"
 )
 
@@ -158,3 +162,218 @@ func newMockPod(name, namespace, nodeName string, volumes []string,
         },
     }
 }
+
+func TestGetSnapshotLimitFromNamespace(t *testing.T) {
+    // Save the original functions and restore them after the tests
+    originalGetConfig := getK8sConfig
+    originalNewK8sClientFromConfig := newK8sClientFromConfig
+    defer func() {
+        getK8sConfig = originalGetConfig
+        newK8sClientFromConfig = originalNewK8sClientFromConfig
+    }()
+
+    // Mock getK8sConfig to return a fake config
+    getK8sConfig = func() (*restclient.Config, error) {
+        return &restclient.Config{}, nil
+    }
+
+    t.Run("WhenAnnotationExists_ValidValue", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name: "test-namespace",
+                Annotations: map[string]string{
+                    common.MaxSnapshotsPerVolumeAnnotationKey: "5",
+                },
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        limit, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.Nil(t, err)
+        assert.Equal(t, 5, limit)
+    })
+
+    t.Run("WhenAnnotationExists_ValueEqualsMax", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name: "test-namespace",
+                Annotations: map[string]string{
+                    common.MaxSnapshotsPerVolumeAnnotationKey: "32",
+                },
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        limit, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.Nil(t, err)
+        assert.Equal(t, 32, limit)
+    })
+
+    t.Run("WhenAnnotationExists_ValueExceedsMax", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name: "test-namespace",
+                Annotations: map[string]string{
+                    common.MaxSnapshotsPerVolumeAnnotationKey: "50",
+                },
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        limit, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.Nil(t, err)
+        assert.Equal(t, 32, limit) // Should be capped
+    })
+
+    t.Run("WhenAnnotationExists_ValueIsZero", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name: "test-namespace",
+                Annotations: map[string]string{
+                    common.MaxSnapshotsPerVolumeAnnotationKey: "0",
+                },
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        limit, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.Nil(t, err)
+        assert.Equal(t, 0, limit) // 0 means block all snapshots
+    })
+
+    t.Run("WhenAnnotationExists_ValueIsNegative", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name: "test-namespace",
+                Annotations: map[string]string{
+                    common.MaxSnapshotsPerVolumeAnnotationKey: "-5",
+                },
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        _, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.NotNil(t, err)
+        statusErr, ok := status.FromError(err)
+        assert.True(t, ok)
+        assert.Equal(t, codes.InvalidArgument, statusErr.Code())
+        assert.Contains(t, err.Error(), "invalid snapshot limit")
+    })
+
+    t.Run("WhenAnnotationExists_InvalidFormat", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name: "test-namespace",
+                Annotations: map[string]string{
+                    common.MaxSnapshotsPerVolumeAnnotationKey: "abc",
+                },
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        _, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.NotNil(t, err)
+        statusErr, ok := status.FromError(err)
+        assert.True(t, ok)
+        assert.Equal(t, codes.Internal, statusErr.Code())
+        assert.Contains(t, err.Error(), "failed to parse annotation")
+    })
+
+    t.Run("WhenAnnotationMissing", func(t *testing.T) {
+        // Setup
+        ns := &v1.Namespace{
+            ObjectMeta: metav1.ObjectMeta{
+                Name:        "test-namespace",
+                Annotations: map[string]string{}, // No annotation
+            },
+        }
+        fakeClient := fake.NewSimpleClientset(ns)
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        limit, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.Nil(t, err)
+        assert.Equal(t, common.DefaultMaxSnapshotsPerBlockVolumeInWCP, limit) // Should return default (4)
+    })
+
+    t.Run("WhenNamespaceNotFound", func(t *testing.T) {
+        // Setup
+        fakeClient := fake.NewSimpleClientset() // Empty clientset
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return fakeClient, nil
+        }
+
+        // Execute
+        _, err := getSnapshotLimitFromNamespace(context.Background(), "non-existent-namespace")
+
+        // Verify
+        assert.NotNil(t, err)
+        statusErr, ok := status.FromError(err)
+        assert.True(t, ok)
+        assert.Equal(t, codes.Internal, statusErr.Code())
+        assert.Contains(t, err.Error(), "failed to get namespace")
+    })
+
+    t.Run("WhenK8sClientCreationFails", func(t *testing.T) {
+        // Setup
+        newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
+            return nil, assert.AnError
+        }
+
+        // Execute
+        _, err := getSnapshotLimitFromNamespace(context.Background(), "test-namespace")
+
+        // Verify
+        assert.NotNil(t, err)
+        statusErr, ok := status.FromError(err)
+        assert.True(t, ok)
+        assert.Equal(t, codes.Internal, statusErr.Code())
+        assert.Contains(t, err.Error(), "failed to create Kubernetes client")
+    })
+}
diff --git a/pkg/csi/service/wcp/controller_test.go b/pkg/csi/service/wcp/controller_test.go
index 131f31aaab..31cdd678d2 100644
--- a/pkg/csi/service/wcp/controller_test.go
+++ b/pkg/csi/service/wcp/controller_test.go
@@ -22,6 +22,7 @@ import (
     "strings"
     "sync"
     "testing"
+    "time"
 
     "google.golang.org/grpc/codes"
     "google.golang.org/grpc/status"
@@ -137,6 +138,9 @@ func getControllerTest(t *testing.T) *controllerTest {
         c := &controller{
             manager:     manager,
             topologyMgr: topologyMgr,
+            snapshotLockMgr: &snapshotLockManager{
+                locks: make(map[string]*volumeLock),
+            },
         }
 
         controllerTestInstance = &controllerTest{
@@ -1602,3 +1606,210 @@ func TestWCPExpandVolumeWithSnapshots(t *testing.T) {
         t.Fatalf("Volume should not exist after deletion with ID: %s", volID)
     }
 }
+
+func TestSnapshotLockManager(t *testing.T) {
+    ct := getControllerTest(t)
+
+    t.Run("AcquireAndRelease_SingleVolume", func(t *testing.T) {
+        volumeID := "test-volume-1"
+
+        // Acquire lock
+        ct.controller.acquireSnapshotLock(ctx, volumeID)
+
+        // Verify lock exists and refCount = 1
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        vLock, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if !exists {
+            t.Fatal("Lock should exist after acquire")
+        }
+        if vLock.refCount != 1 {
+            t.Fatalf("Expected refCount=1, got %d", vLock.refCount)
+        }
+
+        // Release lock
+        ct.controller.releaseSnapshotLock(ctx, volumeID)
+
+        // Verify lock is removed
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        _, exists = ct.controller.snapshotLockMgr.locks[volumeID]
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if exists {
+            t.Fatal("Lock should be removed after release")
+        }
+    })
+
+    t.Run("AcquireMultipleTimes_SameVolume", func(t *testing.T) {
+        volumeID := "test-volume-2"
+
+        // Use two goroutines to acquire the lock
+        var wg sync.WaitGroup
+        acquired := make(chan bool, 2)
+
+        // First goroutine acquires and holds the lock
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            ct.controller.acquireSnapshotLock(ctx, volumeID)
+            acquired <- true
+            // Hold lock briefly
+            time.Sleep(100 * time.Millisecond)
+            ct.controller.releaseSnapshotLock(ctx, volumeID)
+        }()
+
+        // Wait for first goroutine to acquire
+        <-acquired
+
+        // Verify refCount = 1, lock exists.
+        // Read refCount only when the lock exists to avoid a nil dereference.
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        vLock, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+        refCount1 := 0
+        if exists {
+            refCount1 = vLock.refCount
+        }
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if !exists {
+            t.Fatal("Lock should exist")
+        }
+        if refCount1 != 1 {
+            t.Fatalf("Expected refCount=1, got %d", refCount1)
+        }
+
+        // Second goroutine tries to acquire (will be blocked)
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            ct.controller.acquireSnapshotLock(ctx, volumeID)
+            acquired <- true
+            ct.controller.releaseSnapshotLock(ctx, volumeID)
+        }()
+
+        // Give second goroutine time to start waiting
+        time.Sleep(50 * time.Millisecond)
+
+        // Verify refCount increased to 2 (second goroutine is waiting)
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        vLock, exists = ct.controller.snapshotLockMgr.locks[volumeID]
+        refCount2 := 0
+        if exists {
+            refCount2 = vLock.refCount
+        }
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if !exists {
+            t.Fatal("Lock should exist")
+        }
+        if refCount2 != 2 {
+            t.Fatalf("Expected refCount=2, got %d", refCount2)
+        }
+
+        // Wait for both goroutines to complete
+        wg.Wait()
+
+        // Verify lock is removed after both releases
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        _, exists = ct.controller.snapshotLockMgr.locks[volumeID]
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if exists {
+            t.Fatal("Lock should be removed after all releases")
+        }
+    })
+
+    t.Run("AcquireRelease_MultipleVolumes", func(t *testing.T) {
+        volume1 := "test-volume-3"
+        volume2 := "test-volume-4"
+        volume3 := "test-volume-5"
+
+        // Acquire locks for all volumes
+        ct.controller.acquireSnapshotLock(ctx, volume1)
+        ct.controller.acquireSnapshotLock(ctx, volume2)
+        ct.controller.acquireSnapshotLock(ctx, volume3)
+
+        // Verify all locks exist
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        count := len(ct.controller.snapshotLockMgr.locks)
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if count < 3 {
+            t.Fatalf("Expected at least 3 locks, got %d", count)
+        }
+
+        // Release volume2
+        ct.controller.releaseSnapshotLock(ctx, volume2)
+
+        // Verify volume2 removed, others remain
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        _, exists1 := ct.controller.snapshotLockMgr.locks[volume1]
+        _, exists2 := ct.controller.snapshotLockMgr.locks[volume2]
+        _, exists3 := ct.controller.snapshotLockMgr.locks[volume3]
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if !exists1 {
+            t.Fatal("Volume1 lock should still exist")
+        }
+        if exists2 {
+            t.Fatal("Volume2 lock should be removed")
+        }
+        if !exists3 {
+            t.Fatal("Volume3 lock should still exist")
+        }
+
+        // Cleanup
+        ct.controller.releaseSnapshotLock(ctx, volume1)
+        ct.controller.releaseSnapshotLock(ctx, volume3)
+    })
+
+    t.Run("ConcurrentAccess_SameVolume", func(t *testing.T) {
+        volumeID := "test-volume-concurrent"
+        counter := 0
+        var wg sync.WaitGroup
+        goroutines := 5
+
+        for i := 0; i < goroutines; i++ {
+            wg.Add(1)
+            go func() {
+                defer wg.Done()
+                ct.controller.acquireSnapshotLock(ctx, volumeID)
+                defer ct.controller.releaseSnapshotLock(ctx, volumeID)
+
+                // Critical section - increment counter
+                temp := counter
+                // Simulate some work
+                for j := 0; j < 100; j++ {
+                    _ = j * 2
+                }
+                counter = temp + 1
+            }()
+        }
+
+        wg.Wait()
+
+        // Verify counter = goroutines (no race condition)
+        if counter != goroutines {
+            t.Fatalf("Expected counter=%d, got %d (race condition detected)", goroutines, counter)
+        }
+
+        // Verify lock is cleaned up
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        _, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if exists {
+            t.Fatal("Lock should be cleaned up after all goroutines finish")
+        }
+    })
+
+    t.Run("ReleaseNonExistentLock", func(t *testing.T) {
+        volumeID := "non-existent-volume"
+
+        // This should not panic
+        ct.controller.releaseSnapshotLock(ctx, volumeID)
+
+        // Verify no lock was created
+        ct.controller.snapshotLockMgr.mapMutex.RLock()
+        _, exists := ct.controller.snapshotLockMgr.locks[volumeID]
+        ct.controller.snapshotLockMgr.mapMutex.RUnlock()
+
+        if exists {
+            t.Fatal("Lock should not exist after releasing non-existent lock")
+        }
+    })
+}