Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pkg/csi/service/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,16 @@ const (
// Guest cluster.
SupervisorVolumeSnapshotAnnotationKey = "csi.vsphere.guest-initiated-csi-snapshot"

// MaxSnapshotsPerVolumeAnnotationKey is the annotation key set on a Namespace
// in the Supervisor cluster to specify the maximum number of snapshots per volume.
MaxSnapshotsPerVolumeAnnotationKey = "csi.vsphere.max-snapshots-per-volume"

// DefaultMaxSnapshotsPerBlockVolumeInWCP is the default maximum number of snapshots per block volume in WCP,
// used when the namespace does not carry the MaxSnapshotsPerVolumeAnnotationKey annotation.
DefaultMaxSnapshotsPerBlockVolumeInWCP = 4

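// MaxAllowedSnapshotsPerBlockVolume is the hard cap on the maximum number of snapshots per block volume;
// a larger limit configured through the namespace annotation is capped to this value.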
MaxAllowedSnapshotsPerBlockVolume = 32

// AttributeSupervisorVolumeSnapshotClass represents name of VolumeSnapshotClass
AttributeSupervisorVolumeSnapshotClass = "svvolumesnapshotclass"

Expand Down Expand Up @@ -467,6 +477,8 @@ const (
// WCPVMServiceVMSnapshots is a supervisor capability indicating
// if supports_VM_service_VM_snapshots FSS is enabled
WCPVMServiceVMSnapshots = "supports_VM_service_VM_snapshots"
// SnapshotLimitWCP is an internal FSS that enables snapshot limit enforcement in WCP
SnapshotLimitWCP = "snapshot-limit-wcp"
)

var WCPFeatureStates = map[string]struct{}{
Expand Down
115 changes: 110 additions & 5 deletions pkg/csi/service/wcp/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,23 @@ var (
vmMoidToHostMoid, volumeIDToVMMap map[string]string
)

// volumeLock represents a lock for a specific volume with reference counting.
// One volumeLock exists per volume ID inside snapshotLockManager.locks; it is
// created by acquireSnapshotLock and deleted by releaseSnapshotLock once its
// reference count drops to zero.
type volumeLock struct {
// mutex is held by the caller between acquireSnapshotLock and
// releaseSnapshotLock for the volume.
mutex sync.Mutex
// refCount is incremented by acquireSnapshotLock before it blocks on mutex
// and decremented by releaseSnapshotLock. It is only read or written while
// snapshotLockManager.mapMutex is held.
refCount int
}

// snapshotLockManager manages per-volume locks for snapshot operations.
// It is instantiated in controller.Init and accessed through
// acquireSnapshotLock and releaseSnapshotLock.
type snapshotLockManager struct {
// locks maps a volume ID to its reference-counted lock entry.
locks map[string]*volumeLock
// mapMutex guards the locks map and the refCount of every entry in it.
// The acquire/release helpers visible in this file only take it in
// write mode (Lock/Unlock).
mapMutex sync.RWMutex
}

// controller holds the state used by the controller methods in this file and
// embeds csi.UnimplementedControllerServer.
//
// NOTE(review): manager, authMgr and topologyMgr are listed twice below. This
// looks like the removed and added sides of a diff rendered together (the
// added side only re-aligns them to make room for snapshotLockMgr) — confirm
// that the real file declares each field once.
type controller struct {
manager *common.Manager
authMgr common.AuthorizationService
topologyMgr commoncotypes.ControllerTopologyService
// manager is passed to common.GetVCenter and supplies the VolumeManager
// used to query snapshots.
manager *common.Manager
authMgr common.AuthorizationService
topologyMgr commoncotypes.ControllerTopologyService
// snapshotLockMgr provides the per-volume locks taken by
// acquireSnapshotLock / releaseSnapshotLock; it is created in Init.
snapshotLockMgr *snapshotLockManager
csi.UnimplementedControllerServer
}

Expand Down Expand Up @@ -211,6 +224,12 @@ func (c *controller) Init(config *cnsconfig.Config, version string) error {
CryptoClient: cryptoClient,
}

// Initialize snapshot lock manager
c.snapshotLockMgr = &snapshotLockManager{
locks: make(map[string]*volumeLock),
}
log.Info("Initialized snapshot lock manager for per-volume serialization")

vc, err := common.GetVCenter(ctx, c.manager)
if err != nil {
log.Errorf("failed to get vcenter. err=%v", err)
Expand Down Expand Up @@ -447,6 +466,53 @@ func (c *controller) ReloadConfiguration(reconnectToVCFromNewConfig bool) error
return nil
}

// acquireSnapshotLock blocks until the caller holds the per-volume snapshot
// lock for volumeID.
//
// The lock entry is created on first use, and its reference count is
// incremented while the map mutex is held, so a concurrent
// releaseSnapshotLock cannot delete the entry while this caller is still
// waiting for it. The map mutex is released before blocking on the volume
// mutex and is not taken again afterwards, so callers working on other
// volumes are never stalled behind a busy volume and the lock holder does not
// contend on the global map mutex a second time.
//
// Every call must be paired with exactly one releaseSnapshotLock call for the
// same volumeID.
func (c *controller) acquireSnapshotLock(ctx context.Context, volumeID string) {
	log := logger.GetLogger(ctx)
	mgr := c.snapshotLockMgr

	// Look up (or create) the entry and register interest in it inside a
	// narrowly scoped critical section. refCount is captured here because
	// reading it after the map mutex is released would be a data race.
	vLock, refCount := func() (*volumeLock, int) {
		mgr.mapMutex.Lock()
		defer mgr.mapMutex.Unlock()

		entry, exists := mgr.locks[volumeID]
		if !exists {
			entry = &volumeLock{}
			mgr.locks[volumeID] = entry
			log.Debugf("Created new lock for volume %q", volumeID)
		}
		entry.refCount++
		return entry, entry.refCount
	}()

	// Block on the volume lock without holding the map mutex.
	vLock.mutex.Lock()
	log.Debugf("Acquired lock for volume %q, refCount: %d", volumeID, refCount)
}

// releaseSnapshotLock is the counterpart of acquireSnapshotLock. Under the
// map mutex it unlocks the volume's mutex, drops one reference, and removes
// the entry from the map once no references remain. If no entry exists for
// volumeID it logs a warning and does nothing else.
func (c *controller) releaseSnapshotLock(ctx context.Context, volumeID string) {
	log := logger.GetLogger(ctx)
	mgr := c.snapshotLockMgr

	mgr.mapMutex.Lock()
	defer mgr.mapMutex.Unlock()

	entry, found := mgr.locks[volumeID]
	if !found {
		log.Warnf("Attempted to release non-existent lock for volume %q", volumeID)
		return
	}

	entry.mutex.Unlock()
	entry.refCount--
	log.Debugf("Released lock for volume %q, refCount: %d", volumeID, entry.refCount)

	// Other holders or waiters still reference this entry; keep it in the map.
	if entry.refCount != 0 {
		return
	}

	delete(mgr.locks, volumeID)
	log.Debugf("Cleaned up lock for volume %q", volumeID)
}

// createBlockVolume creates a block volume based on the CreateVolumeRequest.
func (c *controller) createBlockVolume(ctx context.Context, req *csi.CreateVolumeRequest,
isWorkloadDomainIsolationEnabled bool, clusterMoIds []string) (
Expand Down Expand Up @@ -2446,8 +2512,47 @@ func (c *controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshot
"Queried VolumeType: %v", volumeType, cnsVolumeDetailsMap[volumeID].VolumeType)
}

// TODO: We may need to add logic to check the limit of max number of snapshots by using
// GlobalMaxSnapshotsPerBlockVolume etc. variables in the future.
// Acquire lock for this volume to serialize snapshot operations
// Check snapshot limit if the feature is enabled
isSnapshotLimitWCPEnabled := commonco.ContainerOrchestratorUtility.IsFSSEnabled(ctx, common.SnapshotLimitWCP)
if isSnapshotLimitWCPEnabled {
c.acquireSnapshotLock(ctx, volumeID)
defer c.releaseSnapshotLock(ctx, volumeID)

// Extract namespace from request parameters
volumeSnapshotNamespace := req.Parameters[common.VolumeSnapshotNamespaceKey]
if volumeSnapshotNamespace == "" {
return nil, logger.LogNewErrorCodef(log, codes.Internal,
"volumesnapshot namespace is not set in the request parameters")
}

// Get snapshot limit from namespace annotation
snapshotLimit, err := getSnapshotLimitFromNamespace(ctx, volumeSnapshotNamespace)
if err != nil {
return nil, logger.LogNewErrorCodef(log, codes.Internal,
"failed to get snapshot limit for namespace %q: %v", volumeSnapshotNamespace, err)
}
log.Infof("Snapshot limit for namespace %q is set to %d", volumeSnapshotNamespace, snapshotLimit)

// Query existing snapshots for this volume
snapshotList, _, err := common.QueryVolumeSnapshotsByVolumeID(ctx, c.manager.VolumeManager, volumeID,
common.QuerySnapshotLimit)
if err != nil {
return nil, logger.LogNewErrorCodef(log, codes.Internal,
"failed to query snapshots for volume %q: %v", volumeID, err)
}

// Check if the limit is exceeded
currentSnapshotCount := len(snapshotList)
if currentSnapshotCount >= snapshotLimit {
return nil, logger.LogNewErrorCodef(log, codes.FailedPrecondition,
"the number of snapshots (%d) on the source volume %s has reached or exceeded "+
"the configured maximum (%d) for namespace %s",
currentSnapshotCount, volumeID, snapshotLimit, volumeSnapshotNamespace)
}
log.Infof("Current snapshot count for volume %q is %d, within limit of %d",
volumeID, currentSnapshotCount, snapshotLimit)
}

// the returned snapshotID below is a combination of CNS VolumeID and CNS SnapshotID concatenated by the "+"
// sign. That is, a string of "<UUID>+<UUID>". Because, all other CNS snapshot APIs still require both
Expand Down
71 changes: 71 additions & 0 deletions pkg/csi/service/wcp/controller_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes"
restclient "k8s.io/client-go/rest"
api "k8s.io/kubernetes/pkg/apis/core"
"sigs.k8s.io/controller-runtime/pkg/client/config"
spv1alpha1 "sigs.k8s.io/vsphere-csi-driver/v3/pkg/apis/storagepool/cns/v1alpha1"
Expand Down Expand Up @@ -849,6 +851,15 @@ func validateControllerPublishVolumeRequesInWcp(ctx context.Context, req *csi.Co

// newK8sClient refers to k8s.NewClient through a package-level variable.
// NOTE(review): presumably indirected so tests can substitute a fake — confirm.
var newK8sClient = k8s.NewClient

// getK8sConfig is a variable that can be overridden for testing.
// It defaults to config.GetConfig and is called by
// getSnapshotLimitFromNamespace to obtain the REST config for its client.
var getK8sConfig = config.GetConfig

// newK8sClientFromConfig is a variable that can be overridden for testing.
// It wraps kubernetes.NewForConfig and returns kubernetes.Interface so a fake
// clientset can be substituted.
//
// On failure it returns an untyped nil interface. Returning the result of
// kubernetes.NewForConfig directly would wrap its nil *Clientset in a non-nil
// kubernetes.Interface, so a caller's "client != nil" check would succeed even
// though construction failed.
var newK8sClientFromConfig = func(c *restclient.Config) (kubernetes.Interface, error) {
	clientset, err := kubernetes.NewForConfig(c)
	if err != nil {
		return nil, err
	}
	return clientset, nil
}

// getPodVMUUID returns the UUID of the VM(running on the node) on which the pod that is trying to
// use the volume is scheduled.
func getPodVMUUID(ctx context.Context, volumeID, nodeName string) (string, error) {
Expand Down Expand Up @@ -950,3 +961,63 @@ func GetZonesFromAccessibilityRequirements(ctx context.Context,
}
return zones, nil
}

// getSnapshotLimitFromNamespace resolves the per-volume snapshot limit that
// applies to the given namespace.
//
// It builds a Kubernetes client through getK8sConfig and
// newK8sClientFromConfig, fetches the Namespace object, and reads the
// common.MaxSnapshotsPerVolumeAnnotationKey annotation:
//   - annotation absent: common.DefaultMaxSnapshotsPerBlockVolumeInWCP is returned;
//   - value not an integer: an error with codes.Internal is returned;
//   - value negative: an error with codes.InvalidArgument is returned;
//   - value above common.MaxAllowedSnapshotsPerBlockVolume: a warning is
//     logged and the cap is returned;
//   - otherwise the parsed value is returned.
//
// Failures to obtain the config, create the client, or fetch the namespace
// are reported with codes.Internal.
func getSnapshotLimitFromNamespace(ctx context.Context, namespace string) (int, error) {
	log := logger.GetLogger(ctx)

	restCfg, err := getK8sConfig()
	if err != nil {
		return 0, logger.LogNewErrorCodef(log, codes.Internal,
			"failed to get Kubernetes config: %v", err)
	}

	clientset, err := newK8sClientFromConfig(restCfg)
	if err != nil {
		return 0, logger.LogNewErrorCodef(log, codes.Internal,
			"failed to create Kubernetes client: %v", err)
	}

	nsObj, err := clientset.CoreV1().Namespaces().Get(ctx, namespace, metav1.GetOptions{})
	if err != nil {
		return 0, logger.LogNewErrorCodef(log, codes.Internal,
			"failed to get namespace %q: %v", namespace, err)
	}

	// No annotation means the namespace uses the default limit.
	rawLimit, found := nsObj.Annotations[common.MaxSnapshotsPerVolumeAnnotationKey]
	if !found {
		log.Infof("Annotation %q not found in namespace %q, using default value %d",
			common.MaxSnapshotsPerVolumeAnnotationKey, namespace, common.DefaultMaxSnapshotsPerBlockVolumeInWCP)
		return common.DefaultMaxSnapshotsPerBlockVolumeInWCP, nil
	}

	parsedLimit, err := strconv.Atoi(rawLimit)
	if err != nil {
		return 0, logger.LogNewErrorCodef(log, codes.Internal,
			"failed to parse annotation %q value %q in namespace %q: %v",
			common.MaxSnapshotsPerVolumeAnnotationKey, rawLimit, namespace, err)
	}

	switch {
	case parsedLimit < 0:
		// Negative limits are rejected outright.
		return 0, logger.LogNewErrorCodef(log, codes.InvalidArgument,
			"invalid snapshot limit %d in namespace %q: must be >= 0", parsedLimit, namespace)
	case parsedLimit > common.MaxAllowedSnapshotsPerBlockVolume:
		// Over-large limits are clamped to the hard cap rather than rejected.
		log.Warnf("Snapshot limit %d in namespace %q exceeds maximum allowed %d, capping to %d",
			parsedLimit, namespace, common.MaxAllowedSnapshotsPerBlockVolume, common.MaxAllowedSnapshotsPerBlockVolume)
		return common.MaxAllowedSnapshotsPerBlockVolume, nil
	}

	log.Infof("Snapshot limit for namespace %q is set to %d", namespace, parsedLimit)
	return parsedLimit, nil
}
Loading