diff --git a/pkg/blob/azure.go b/pkg/blob/azure.go index 97694ae8e..d566e365b 100644 --- a/pkg/blob/azure.go +++ b/pkg/blob/azure.go @@ -31,8 +31,10 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/pointer" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/configloader" + azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" azure "sigs.k8s.io/cloud-provider-azure/pkg/provider" providerconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" + "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) var ( @@ -181,9 +183,10 @@ func (d *Driver) getKeyvaultToken() (authorizer autorest.Authorizer, err error) return authorizer, nil } -func (d *Driver) updateSubnetServiceEndpoints(ctx context.Context, vnetResourceGroup, vnetName, subnetName string) error { +func (d *Driver) updateSubnetServiceEndpoints(ctx context.Context, vnetResourceGroup, vnetName, subnetName string) ([]string, error) { + var vnetResourceIDs []string if d.cloud.SubnetsClient == nil { - return fmt.Errorf("SubnetsClient is nil") + return vnetResourceIDs, fmt.Errorf("SubnetsClient is nil") } if vnetResourceGroup == "" { @@ -197,56 +200,89 @@ func (d *Driver) updateSubnetServiceEndpoints(ctx context.Context, vnetResourceG if vnetName == "" { vnetName = d.cloud.VnetName } - if subnetName == "" { - subnetName = d.cloud.SubnetName - } klog.V(2).Infof("updateSubnetServiceEndpoints on vnetName: %s, subnetName: %s, location: %s", vnetName, subnetName, location) - if subnetName == "" || vnetName == "" || location == "" { - return fmt.Errorf("value of subnetName, vnetName or location is empty") + if vnetName == "" || location == "" { + return vnetResourceIDs, fmt.Errorf("vnetName or location is empty") } lockKey := vnetResourceGroup + vnetName + subnetName - d.subnetLockMap.LockEntry(lockKey) - defer d.subnetLockMap.UnlockEntry(lockKey) - - subnet, err := d.cloud.SubnetsClient.Get(ctx, vnetResourceGroup, vnetName, subnetName, "") + cache, err := d.subnetCache.Get(lockKey, azcache.CacheReadTypeDefault) if err != nil { - return fmt.Errorf("failed to get the subnet %s under vnet %s: %v", subnetName, vnetName, err) - } - endpointLocaions := []string{location} - storageServiceEndpoint := network.ServiceEndpointPropertiesFormat{ - Service: &storageService, - Locations: &endpointLocaions, - } - storageServiceExists := false - if subnet.SubnetPropertiesFormat == nil { - subnet.SubnetPropertiesFormat = &network.SubnetPropertiesFormat{} + return nil, err } - if subnet.SubnetPropertiesFormat.ServiceEndpoints == nil { - subnet.SubnetPropertiesFormat.ServiceEndpoints = &[]network.ServiceEndpointPropertiesFormat{} + if cache != nil { + vnetResourceIDs = cache.([]string) + klog.V(2).Infof("subnet %s under vnet %s in rg %s is already updated, vnetResourceIDs: %v", subnetName, vnetName, vnetResourceGroup, vnetResourceIDs) + return vnetResourceIDs, nil } - serviceEndpoints := *subnet.SubnetPropertiesFormat.ServiceEndpoints - for _, v := range serviceEndpoints { - if strings.HasPrefix(pointer.StringDeref(v.Service, ""), storageService) { - storageServiceExists = true - klog.V(4).Infof("serviceEndpoint(%s) is already in subnet(%s)", storageService, subnetName) - break + + d.subnetLockMap.LockEntry(lockKey) + defer d.subnetLockMap.UnlockEntry(lockKey) + + var subnets []network.Subnet + if subnetName != "" { + // list multiple subnets separated by comma + subnetNames := strings.Split(subnetName, ",") + for _, sn := range subnetNames { + sn = strings.TrimSpace(sn) + subnet, rerr := d.cloud.SubnetsClient.Get(ctx, vnetResourceGroup, vnetName, sn, "") + if rerr != nil { + return vnetResourceIDs, fmt.Errorf("failed to get the subnet %s under rg %s vnet %s: %v", subnetName, vnetResourceGroup, vnetName, rerr.Error()) + } + subnets = append(subnets, subnet) + } + } else { + var rerr *retry.Error + subnets, rerr = d.cloud.SubnetsClient.List(ctx, vnetResourceGroup, vnetName) + if rerr != nil { + return vnetResourceIDs, fmt.Errorf("failed to list the subnets under rg %s vnet %s: %v", vnetResourceGroup, vnetName, rerr.Error()) } } - if !storageServiceExists { - serviceEndpoints = append(serviceEndpoints, storageServiceEndpoint) - subnet.SubnetPropertiesFormat.ServiceEndpoints = &serviceEndpoints + for _, subnet := range subnets { + if subnet.Name == nil { + return vnetResourceIDs, fmt.Errorf("subnet name is nil") + } + sn := *subnet.Name + vnetResourceID := d.getSubnetResourceID(vnetResourceGroup, vnetName, sn) + klog.V(2).Infof("set vnetResourceID %s", vnetResourceID) + vnetResourceIDs = append(vnetResourceIDs, vnetResourceID) + + endpointLocaions := []string{location} + storageServiceEndpoint := network.ServiceEndpointPropertiesFormat{ + Service: &storageService, + Locations: &endpointLocaions, + } + storageServiceExists := false + if subnet.SubnetPropertiesFormat == nil { + subnet.SubnetPropertiesFormat = &network.SubnetPropertiesFormat{} + } + if subnet.SubnetPropertiesFormat.ServiceEndpoints == nil { + subnet.SubnetPropertiesFormat.ServiceEndpoints = &[]network.ServiceEndpointPropertiesFormat{} + } + serviceEndpoints := *subnet.SubnetPropertiesFormat.ServiceEndpoints + for _, v := range serviceEndpoints { + if strings.HasPrefix(pointer.StringDeref(v.Service, ""), storageService) { + storageServiceExists = true + klog.V(4).Infof("serviceEndpoint(%s) is already in subnet(%s)", storageService, sn) + break + } + } + + if !storageServiceExists { + serviceEndpoints = append(serviceEndpoints, storageServiceEndpoint) + subnet.SubnetPropertiesFormat.ServiceEndpoints = &serviceEndpoints - klog.V(2).Infof("begin to update the subnet %s under vnet %s rg %s", subnetName, vnetName, vnetResourceGroup) - if err := d.cloud.SubnetsClient.CreateOrUpdate(ctx, vnetResourceGroup, vnetName, subnetName, subnet); err != nil { - return fmt.Errorf("failed to update the subnet %s under vnet %s: %v", subnetName, vnetName, err) + klog.V(2).Infof("begin to update the subnet %s under vnet %s in rg %s", sn, vnetName, vnetResourceGroup) + if err := d.cloud.SubnetsClient.CreateOrUpdate(ctx, vnetResourceGroup, vnetName, sn, subnet); err != nil { + return vnetResourceIDs, fmt.Errorf("failed to update the subnet %s under vnet %s: %v", sn, vnetName, err) + } } - klog.V(2).Infof("serviceEndpoint(%s) is appended in subnet(%s)", storageService, subnetName) } - - return nil + // cache the subnet update + d.subnetCache.Set(lockKey, vnetResourceIDs) + return vnetResourceIDs, nil } func (d *Driver) getStorageEndPointSuffix() string { diff --git a/pkg/blob/azure_test.go b/pkg/blob/azure_test.go index e1c67979c..4a701f774 100644 --- a/pkg/blob/azure_test.go +++ b/pkg/blob/azure_test.go @@ -32,12 +32,11 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/client-go/kubernetes" + "k8s.io/utils/pointer" "sigs.k8s.io/blob-csi-driver/pkg/util" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient" azureprovider "sigs.k8s.io/cloud-provider-azure/pkg/provider" - - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) // TestGetCloudProvider tests the func getCloudProvider(). @@ -328,25 +327,14 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) { testFunc func(t *testing.T) }{ { - name: "[fail] no subnet", - testFunc: func(t *testing.T) { - retErr := retry.NewError(false, fmt.Errorf("the subnet does not exist")) - mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.Subnet{}, retErr).Times(1) - expectedErr := fmt.Errorf("failed to get the subnet %s under vnet %s: %v", config.SubnetName, config.VnetName, retErr) - err := d.updateSubnetServiceEndpoints(ctx, "", "", "") - if !reflect.DeepEqual(err, expectedErr) { - t.Errorf("Unexpected error: %v", err) - } - }, - }, - { - name: "[success] subnetPropertiesFormat is nil", + name: "[fail] subnet name is nil", testFunc: func(t *testing.T) { mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.Subnet{}, nil).Times(1) mockSubnetClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) - err := d.updateSubnetServiceEndpoints(ctx, "", "", "") - if !reflect.DeepEqual(err, nil) { + _, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname") + expectedErr := fmt.Errorf("subnet name is nil") + if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } }, @@ -356,12 +344,11 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) { testFunc: func(t *testing.T) { fakeSubnet := network.Subnet{ SubnetPropertiesFormat: &network.SubnetPropertiesFormat{}, + Name: pointer.String("subnetName"), } mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).Times(1) - mockSubnetClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) - - err := d.updateSubnetServiceEndpoints(ctx, "", "", "") + _, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname") if !reflect.DeepEqual(err, nil) { t.Errorf("Unexpected error: %v", err) } @@ -374,12 +361,12 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) { SubnetPropertiesFormat: &network.SubnetPropertiesFormat{ ServiceEndpoints: &[]network.ServiceEndpointPropertiesFormat{}, }, + Name: pointer.String("subnetName"), } - mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).Times(1) - mockSubnetClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) + mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).AnyTimes() - err := d.updateSubnetServiceEndpoints(ctx, "", "", "") + _, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname") if !reflect.DeepEqual(err, nil) { t.Errorf("Unexpected error: %v", err) } @@ -396,11 +383,12 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) { }, }, }, + Name: pointer.String("subnetName"), } - mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).Times(1) + mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).AnyTimes() - err := d.updateSubnetServiceEndpoints(ctx, "", "", "") + _, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname") if !reflect.DeepEqual(err, nil) { t.Errorf("Unexpected error: %v", err) } @@ -411,7 +399,7 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) { testFunc: func(t *testing.T) { d.cloud.SubnetsClient = nil expectedErr := fmt.Errorf("SubnetsClient is nil") - err := d.updateSubnetServiceEndpoints(ctx, "", "", "") + _, err := d.updateSubnetServiceEndpoints(ctx, "", "", "") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } diff --git a/pkg/blob/blob.go b/pkg/blob/blob.go index bf60fed70..ff3f27c2c 100644 --- a/pkg/blob/blob.go +++ b/pkg/blob/blob.go @@ -248,6 +248,8 @@ type Driver struct { volStatsCache azcache.Resource // a timed cache storing account which should use sastoken for azcopy based volume cloning azcopySasTokenCache azcache.Resource + // a timed cache storing subnet operations + subnetCache azcache.Resource // sas expiry time for azcopy in volume clone sasTokenExpirationMinutes int // timeout in minutes for waiting for azcopy to finish @@ -306,6 +308,10 @@ func NewDriver(options *DriverOptions, kubeClient kubernetes.Interface, cloud *p if d.volStatsCache, err = azcache.NewTimedCache(time.Duration(options.VolStatsCacheExpireInMinutes)*time.Minute, getter, false); err != nil { klog.Fatalf("%v", err) } + if d.subnetCache, err = azcache.NewTimedCache(10*time.Minute, getter, false); err != nil { + klog.Fatalf("%v", err) + } + d.mounter = &mount.SafeFormatAndMount{ Interface: mount.New(""), Exec: utilexec.New(), diff --git a/pkg/blob/blob_test.go b/pkg/blob/blob_test.go index 95d797de5..6ee76bb97 100644 --- a/pkg/blob/blob_test.go +++ b/pkg/blob/blob_test.go @@ -97,6 +97,7 @@ func TestNewDriver(t *testing.T) { fakedriver.dataPlaneAPIVolCache = driver.dataPlaneAPIVolCache fakedriver.azcopySasTokenCache = driver.azcopySasTokenCache fakedriver.volStatsCache = driver.volStatsCache + fakedriver.subnetCache = driver.subnetCache fakedriver.cloud = driver.cloud assert.Equal(t, driver, fakedriver) } diff --git a/pkg/blob/controllerserver.go b/pkg/blob/controllerserver.go index caa7c2e92..737b92730 100644 --- a/pkg/blob/controllerserver.go +++ b/pkg/blob/controllerserver.go @@ -279,15 +279,9 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) storeAccountKey = false if !pointer.BoolDeref(createPrivateEndpoint, false) { // set VirtualNetworkResourceIDs for storage account firewall setting - subnets := strings.Split(subnetName, ",") - for _, subnet := range subnets { - subnet = strings.TrimSpace(subnet) - vnetResourceID := d.getSubnetResourceID(vnetResourceGroup, vnetName, subnet) - klog.V(2).Infof("set vnetResourceID(%s) for NFS protocol", vnetResourceID) - vnetResourceIDs = append(vnetResourceIDs, vnetResourceID) - if err := d.updateSubnetServiceEndpoints(ctx, vnetResourceGroup, vnetName, subnet); err != nil { - return nil, status.Errorf(codes.Internal, "update service endpoints failed with error: %v", err) - } + var err error + if vnetResourceIDs, err = d.updateSubnetServiceEndpoints(ctx, vnetResourceGroup, vnetName, subnetName); err != nil { + return nil, status.Errorf(codes.Internal, "update service endpoints failed with error: %v", err) } } }