From d069037b061e0a0902ebc399128d0e7e9a6d7c44 Mon Sep 17 00:00:00 2001 From: Priyansh Choudhary Date: Tue, 14 Oct 2025 14:21:09 +0530 Subject: [PATCH] fix: create volume failure when requested pvc size is smaller than snapshot disk size refactor: streamline snapshot handling in CreateVolume and update fake driver methods fix: adjust disk size handling for snapshot creation and improve size validation logic fix: increase max snapshot size difference allowed from 1 GiB to 50 GiB refactor: simplify snapshot size and SKU retrieval functions --- pkg/azuredisk/azure_managedDiskController.go | 19 +- pkg/azuredisk/controllerserver.go | 81 ++++- pkg/azuredisk/controllerserver_test.go | 304 ++++++++++++++++++- pkg/azuredisk/fake_azuredisk.go | 1 + 4 files changed, 396 insertions(+), 9 deletions(-) diff --git a/pkg/azuredisk/azure_managedDiskController.go b/pkg/azuredisk/azure_managedDiskController.go index e9121ebc65..a750786f5c 100644 --- a/pkg/azuredisk/azure_managedDiskController.go +++ b/pkg/azuredisk/azure_managedDiskController.go @@ -34,6 +34,7 @@ import ( volumehelpers "k8s.io/cloud-provider/volume/helpers" "k8s.io/klog/v2" "k8s.io/utils/ptr" + csidriverconsts "sigs.k8s.io/azuredisk-csi-driver/pkg/azureconstants" azureconsts "sigs.k8s.io/azuredisk-csi-driver/pkg/azureconstants" "sigs.k8s.io/azuredisk-csi-driver/pkg/azureutils" @@ -153,10 +154,20 @@ func (c *ManagedDiskController) CreateManagedDisk(ctx context.Context, options * if err != nil { return "", err } - diskProperties := armcompute.DiskProperties{ - DiskSizeGB: &diskSizeGB, - CreationData: &creationData, - BurstingEnabled: options.BurstingEnabled, + + diskProperties := armcompute.DiskProperties{} + // when creating from snapshot, and diskSizeGB is 0, let disk RP calculate size from snapshot bytes size. + if diskSizeGB == 0 && options.SourceType == csidriverconsts.SourceSnapshot { + diskProperties = armcompute.DiskProperties{ + CreationData: &creationData, + BurstingEnabled: options.BurstingEnabled, + } + } else { + diskProperties = armcompute.DiskProperties{ + CreationData: &creationData, + BurstingEnabled: options.BurstingEnabled, + DiskSizeGB: &diskSizeGB, + } } if options.PublicNetworkAccess != "" { diff --git a/pkg/azuredisk/controllerserver.go b/pkg/azuredisk/controllerserver.go index 5d92010492..984284b1cf 100644 --- a/pkg/azuredisk/controllerserver.go +++ b/pkg/azuredisk/controllerserver.go @@ -49,10 +49,11 @@ import ( ) const ( - waitForSnapshotReadyInterval = 5 * time.Second - waitForSnapshotReadyTimeout = 10 * time.Minute - maxErrMsgLength = 990 - checkDiskLunThrottleLatency = 1 * time.Second + waitForSnapshotReadyInterval = 5 * time.Second + waitForSnapshotReadyTimeout = 10 * time.Minute + maxErrMsgLength = 990 + checkDiskLunThrottleLatency = 1 * time.Second + maxSnapshotSizeDifferenceAllowed = 50 // in GiB ) // listVolumeStatus explains the return status of `listVolumesByResourceGroup` @@ -107,6 +108,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) capacityBytes := req.GetCapacityRange().GetRequiredBytes() volSizeBytes := int64(capacityBytes) + requestSizeToBeSupplied := true requestGiB := int(volumehelper.RoundUpGiB(volSizeBytes)) if diskParams.PerformancePlus != nil && *diskParams.PerformancePlus && requestGiB < consts.PerformancePlusMinimumDiskSizeGiB { @@ -225,6 +227,26 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) }, }, } + + // Get snapshot once and extract both SKU and disk size info + snapshot, err := d.getSnapshot(ctx, sourceID) + if err == nil { + // fetch snapshot info and compare size with requested size + // when snapshot size is larger than requested size, do not supply the size in the request + // to allow Azure to create a disk with exact snapshot size in bytes. + diskSizeInBytes, err := getDiskSizeInBytesFromSnapshot(snapshot) + if err == nil { + requestedGiBfromSnapshot := int(volumehelper.RoundUpGiB(diskSizeInBytes)) + differenceSize := requestedGiBfromSnapshot - requestGiB + if requestedGiBfromSnapshot > requestGiB && differenceSize <= maxSnapshotSizeDifferenceAllowed { + klog.V(4).Infof("snapshot size (%d GiB) is larger than requested size (%d GiB) but difference (%d GiB) is within the allowed limit (%d GiB), will not supply the size in the create disk request", requestedGiBfromSnapshot, requestGiB, differenceSize, maxSnapshotSizeDifferenceAllowed) + requestSizeToBeSupplied = false + } + } + } else { + return nil, status.Errorf(codes.NotFound, "%v", err) + } + metricsRequest = "controller_create_volume_from_snapshot" } else { sourceID = content.GetVolume().GetVolumeId() @@ -295,6 +317,11 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } diskParams.VolumeContext[consts.RequestedSizeGib] = strconv.Itoa(requestGiB) + + if !requestSizeToBeSupplied && sourceType == consts.SourceSnapshot { + requestGiB = 0 + } + volumeOptions := &ManagedDiskOptions{ AvailabilityZone: diskZone, BurstingEnabled: diskParams.EnableBursting, @@ -1336,3 +1363,49 @@ func (d *Driver) GetSourceDiskSize(ctx context.Context, subsID, resourceGroup, d } return (*result.Properties).DiskSizeGB, result, nil } + +// getSnapshot retrieves the Snapshot and returns the Snapshot or the error if any error occurs +func (d *Driver) getSnapshot(ctx context.Context, sourceID string) (*armcompute.Snapshot, error) { + subsID, resourceGroup, snapshotName, err := azureutils.GetInfoFromURI(sourceID) + if err != nil { + klog.Warningf("could not get subscription id, resource group from snapshot uri (%s) with error(%v)", sourceID, err) + return nil, err + } + snapClient, err := d.clientFactory.GetSnapshotClientForSub(subsID) + if err != nil { + klog.Warningf("could not get snapshot client for subscription(%s) with error(%v)", subsID, err) + return nil, err + } + snapshotRetrieved, err := snapClient.Get(ctx, resourceGroup, snapshotName) + if err != nil { + klog.Warningf("get snapshot %s from rg(%s) error: %v", snapshotName, resourceGroup, err) + return nil, err + } + return snapshotRetrieved, nil +} + +// getSnapshotSKU retrieves the SKU of the snapshot and returns the SKU or if any error occurs +func getSnapshotSKUFromSnapshot(computeSnapshot *armcompute.Snapshot) (string, error) { + if computeSnapshot == nil { + klog.Warningf("Snapshot is nil") + return "", status.Error(codes.NotFound, "Snapshot is nil") + } + if computeSnapshot.SKU == nil || computeSnapshot.SKU.Name == nil { + klog.Warningf("Snapshot or Snapshot Properties SKU not found for snapshot") + return "", status.Error(codes.NotFound, "Snapshot SKU property not found") + } + return string(*computeSnapshot.SKU.Name), nil +} + +// getDiskSizeInBytes retrieves the size of the disk and returns the size or if any error occurs +func getDiskSizeInBytesFromSnapshot(computeSnapshot *armcompute.Snapshot) (int64, error) { + if computeSnapshot == nil { + klog.Warningf("Snapshot is nil") + return 0, status.Error(codes.NotFound, "Snapshot is nil") + } + if computeSnapshot.Properties == nil || computeSnapshot.Properties.DiskSizeBytes == nil { + klog.Warningf("Snapshot or Snapshot Properties.DiskSizeBytes not found for snapshot") + return 0, status.Error(codes.NotFound, "Snapshot size not found") + } + return *computeSnapshot.Properties.DiskSizeBytes, nil +} diff --git a/pkg/azuredisk/controllerserver_test.go b/pkg/azuredisk/controllerserver_test.go index 09d30f2de4..81b450c754 100644 --- a/pkg/azuredisk/controllerserver_test.go +++ b/pkg/azuredisk/controllerserver_test.go @@ -30,6 +30,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -371,7 +372,9 @@ func TestCreateVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } disk := &armcompute.Disk{ - Properties: &armcompute.DiskProperties{}, + Properties: &armcompute.DiskProperties{ + DiskSizeBytes: ptr.To(int64(1073741824)), // 1GB in bytes + }, } diskClient := mock_diskclient.NewMockInterface(cntl) d.getClientFactory().(*mock_azclient.MockClientFactory).EXPECT().GetDiskClientForSub(gomock.Any()).Return(diskClient, nil).AnyTimes() @@ -3590,6 +3593,305 @@ func TestGetSourceDiskSize(t *testing.T) { } } +func TestGetSnapshotSKU(t *testing.T) { + type testCase struct { + name string + snapshotURI string + setupMocks func(factory *mock_azclient.MockClientFactory, snap *mock_snapshotclient.MockInterface) + expectedSKU string + expectFactory bool + expectErrSubstr string + expectGRPCCode codes.Code + } + tests := []testCase{ + { + name: "success premium", + snapshotURI: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/snapshots/snap", + expectFactory: true, + expectedSKU: string(armcompute.SnapshotStorageAccountTypesPremiumLRS), + setupMocks: func(f *mock_azclient.MockClientFactory, s *mock_snapshotclient.MockInterface) { + f.EXPECT().GetSnapshotClientForSub("sub").Return(s, nil) + s.EXPECT(). + Get(gomock.Any(), "rg", "snap"). + Return(&armcompute.Snapshot{ + SKU: &armcompute.SnapshotSKU{ + Name: to.Ptr(armcompute.SnapshotStorageAccountTypesPremiumLRS), + }, + }, nil) + }, + }, + { + name: "bad URI", + snapshotURI: "bad-uri", + expectErrSubstr: "invalid URI", + expectGRPCCode: codes.NotFound, + setupMocks: func(_ *mock_azclient.MockClientFactory, _ *mock_snapshotclient.MockInterface) {}, + }, + { + name: "factory error -> empty string", + snapshotURI: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/snapshots/snap", + expectFactory: true, + expectedSKU: "", + expectErrSubstr: "factory error", + setupMocks: func(f *mock_azclient.MockClientFactory, _ *mock_snapshotclient.MockInterface) { + f.EXPECT().GetSnapshotClientForSub("sub").Return(nil, fmt.Errorf("factory error")) + }, + }, + { + name: "get error -> empty string", + snapshotURI: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/snapshots/snap", + expectFactory: true, + expectedSKU: "", + expectErrSubstr: "get error", + setupMocks: func(f *mock_azclient.MockClientFactory, s *mock_snapshotclient.MockInterface) { + f.EXPECT().GetSnapshotClientForSub("sub").Return(s, nil) + s.EXPECT().Get(gomock.Any(), "rg", "snap").Return(nil, fmt.Errorf("get error")) + }, + }, + { + name: "nil snapshot result -> empty string", + snapshotURI: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/snapshots/snap", + expectFactory: true, + expectedSKU: "", + expectErrSubstr: "Snapshot is nil", + setupMocks: func(f *mock_azclient.MockClientFactory, s *mock_snapshotclient.MockInterface) { + f.EXPECT().GetSnapshotClientForSub("sub").Return(s, nil) + s.EXPECT().Get(gomock.Any(), "rg", "snap").Return(nil, nil) + }, + }, + { + name: "nil SKU struct -> empty string", + snapshotURI: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/snapshots/snap", + expectFactory: true, + expectedSKU: "", + expectErrSubstr: "Snapshot SKU property not found", + setupMocks: func(f *mock_azclient.MockClientFactory, s *mock_snapshotclient.MockInterface) { + f.EXPECT().GetSnapshotClientForSub("sub").Return(s, nil) + s.EXPECT().Get(gomock.Any(), "rg", "snap").Return(&armcompute.Snapshot{}, nil) + }, + }, + { + name: "nil SKU name -> empty string", + snapshotURI: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/snapshots/snap", + expectFactory: true, + expectedSKU: "", + expectErrSubstr: "Snapshot SKU property not found", + setupMocks: func(f *mock_azclient.MockClientFactory, s *mock_snapshotclient.MockInterface) { + f.EXPECT().GetSnapshotClientForSub("sub").Return(s, nil) + s.EXPECT().Get(gomock.Any(), "rg", "snap").Return(&armcompute.Snapshot{ + SKU: &armcompute.SnapshotSKU{Name: nil}, + }, nil) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + d, _ := NewFakeDriver(ctrl) + + factory, ok := d.getClientFactory().(*mock_azclient.MockClientFactory) + if !ok { + t.Fatalf("clientFactory is not a *mock_azclient.MockClientFactory") + } + snapClient := mock_snapshotclient.NewMockInterface(ctrl) + + if tc.setupMocks != nil { + tc.setupMocks(factory, snapClient) + } + + invoke := func() (string, error) { + snapshot, err := d.getSnapshot(context.Background(), tc.snapshotURI) + if err != nil { + return "", err + } + return getSnapshotSKUFromSnapshot(snapshot) + } + + sku, err := invoke() + + if tc.expectErrSubstr != "" { + require.Error(t, err) + require.Empty(t, sku) + require.Contains(t, err.Error(), tc.expectErrSubstr) + if tc.expectGRPCCode != 0 { + if st, ok := status.FromError(err); ok { + require.Equal(t, tc.expectGRPCCode, st.Code()) + } + } + } else { + require.NoError(t, err) + require.NotNil(t, sku) + require.Equal(t, tc.expectedSKU, sku) + } + }) + } +} + +func TestGetSnapshotSKUFromSnapshot(t *testing.T) { + type testCase struct { + name string + snapshot *armcompute.Snapshot + expectedSKU string + expectErrSubstr string + expectGRPCCode codes.Code + } + tests := []testCase{ + { + name: "success - premium LRS", + expectedSKU: string(armcompute.SnapshotStorageAccountTypesPremiumLRS), + snapshot: &armcompute.Snapshot{ + SKU: &armcompute.SnapshotSKU{ + Name: to.Ptr(armcompute.SnapshotStorageAccountTypesPremiumLRS), + }, + }, + }, + { + name: "success - standard LRS", + expectedSKU: string(armcompute.SnapshotStorageAccountTypesStandardLRS), + snapshot: &armcompute.Snapshot{ + SKU: &armcompute.SnapshotSKU{ + Name: to.Ptr(armcompute.SnapshotStorageAccountTypesStandardLRS), + }, + }, + }, + { + name: "nil snapshot", + snapshot: nil, + expectedSKU: "", + expectErrSubstr: "Snapshot is nil", + expectGRPCCode: codes.NotFound, + }, + { + name: "nil SKU", + snapshot: &armcompute.Snapshot{}, + expectedSKU: "", + expectErrSubstr: "Snapshot SKU property not found", + expectGRPCCode: codes.NotFound, + }, + { + name: "nil SKU Name", + snapshot: &armcompute.Snapshot{ + SKU: &armcompute.SnapshotSKU{Name: nil}, + }, + expectedSKU: "", + expectErrSubstr: "Snapshot SKU property not found", + expectGRPCCode: codes.NotFound, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sku, err := getSnapshotSKUFromSnapshot(tc.snapshot) + + if tc.expectErrSubstr != "" { + require.Error(t, err) + require.Empty(t, sku) + require.Contains(t, err.Error(), tc.expectErrSubstr) + if tc.expectGRPCCode != 0 { + if st, ok := status.FromError(err); ok { + require.Equal(t, tc.expectGRPCCode, st.Code()) + } + } + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedSKU, sku) + } + }) + } +} + +func TestGetDiskSizeInBytesFromSnapshot(t *testing.T) { + type testCase struct { + name string + snapshot *armcompute.Snapshot + expectedSize int64 + expectErrSubstr string + expectGRPCCode codes.Code + } + tests := []testCase{ + { + name: "success - 100GB disk", + expectedSize: 107374182400, // 100GB in bytes + snapshot: &armcompute.Snapshot{ + Properties: &armcompute.SnapshotProperties{ + DiskSizeBytes: to.Ptr(int64(107374182400)), + }, + }, + }, + { + name: "success - 1TB disk", + expectedSize: 1099511627776, // 1TB in bytes + snapshot: &armcompute.Snapshot{ + Properties: &armcompute.SnapshotProperties{ + DiskSizeBytes: to.Ptr(int64(1099511627776)), + }, + }, + }, + { + name: "success - small disk", + expectedSize: 4294967296, // 4GB in bytes + snapshot: &armcompute.Snapshot{ + Properties: &armcompute.SnapshotProperties{ + DiskSizeBytes: to.Ptr(int64(4294967296)), + }, + }, + }, + { + name: "nil snapshot", + snapshot: nil, + expectedSize: 0, + expectErrSubstr: "Snapshot is nil", + expectGRPCCode: codes.NotFound, + }, + { + name: "nil Properties", + snapshot: &armcompute.Snapshot{}, + expectedSize: 0, + expectErrSubstr: "Snapshot size not found", + expectGRPCCode: codes.NotFound, + }, + { + name: "nil DiskSizeBytes", + snapshot: &armcompute.Snapshot{ + Properties: &armcompute.SnapshotProperties{ + DiskSizeBytes: nil, + }, + }, + expectedSize: 0, + expectErrSubstr: "Snapshot size not found", + expectGRPCCode: codes.NotFound, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + size, err := getDiskSizeInBytesFromSnapshot(tc.snapshot) + + if tc.expectErrSubstr != "" { + require.Error(t, err) + require.Equal(t, int64(0), size) + require.Contains(t, err.Error(), tc.expectErrSubstr) + if tc.expectGRPCCode != 0 { + if st, ok := status.FromError(err); ok { + require.Equal(t, tc.expectGRPCCode, st.Code()) + } + } + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedSize, size) + } + }) + } +} + func getFakeDriverWithKubeClient(ctrl *gomock.Controller) FakeDriver { d, _ := NewFakeDriver(ctrl) diff --git a/pkg/azuredisk/fake_azuredisk.go b/pkg/azuredisk/fake_azuredisk.go index ac7dacff6a..f2ef64faa8 100644 --- a/pkg/azuredisk/fake_azuredisk.go +++ b/pkg/azuredisk/fake_azuredisk.go @@ -92,6 +92,7 @@ type FakeDriver interface { checkDiskExists(ctx context.Context, diskURI string) (*armcompute.Disk, error) waitForSnapshotReady(context.Context, string, string, string, time.Duration, time.Duration) error getSnapshotByID(context.Context, string, string, string, string) (*csi.Snapshot, error) + getSnapshot(context.Context, string) (*armcompute.Snapshot, error) ensureMountPoint(string) (bool, error) ensureBlockTargetFile(string) error getDevicePathWithLUN(lunStr string) (string, error)