Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 74 additions & 38 deletions pkg/blob/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ import (
"k8s.io/klog/v2"
"k8s.io/utils/pointer"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/configloader"
azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache"
azure "sigs.k8s.io/cloud-provider-azure/pkg/provider"
providerconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config"
"sigs.k8s.io/cloud-provider-azure/pkg/retry"
)

var (
Expand Down Expand Up @@ -181,9 +183,10 @@ func (d *Driver) getKeyvaultToken() (authorizer autorest.Authorizer, err error)
return authorizer, nil
}

func (d *Driver) updateSubnetServiceEndpoints(ctx context.Context, vnetResourceGroup, vnetName, subnetName string) error {
func (d *Driver) updateSubnetServiceEndpoints(ctx context.Context, vnetResourceGroup, vnetName, subnetName string) ([]string, error) {
var vnetResourceIDs []string
if d.cloud.SubnetsClient == nil {
return fmt.Errorf("SubnetsClient is nil")
return vnetResourceIDs, fmt.Errorf("SubnetsClient is nil")
}

if vnetResourceGroup == "" {
Expand All @@ -197,56 +200,89 @@ func (d *Driver) updateSubnetServiceEndpoints(ctx context.Context, vnetResourceG
if vnetName == "" {
vnetName = d.cloud.VnetName
}
if subnetName == "" {
subnetName = d.cloud.SubnetName
}

klog.V(2).Infof("updateSubnetServiceEndpoints on vnetName: %s, subnetName: %s, location: %s", vnetName, subnetName, location)
if subnetName == "" || vnetName == "" || location == "" {
return fmt.Errorf("value of subnetName, vnetName or location is empty")
if vnetName == "" || location == "" {
return vnetResourceIDs, fmt.Errorf("vnetName or location is empty")
}

lockKey := vnetResourceGroup + vnetName + subnetName
d.subnetLockMap.LockEntry(lockKey)
defer d.subnetLockMap.UnlockEntry(lockKey)

subnet, err := d.cloud.SubnetsClient.Get(ctx, vnetResourceGroup, vnetName, subnetName, "")
cache, err := d.subnetCache.Get(lockKey, azcache.CacheReadTypeDefault)
if err != nil {
return fmt.Errorf("failed to get the subnet %s under vnet %s: %v", subnetName, vnetName, err)
}
endpointLocaions := []string{location}
storageServiceEndpoint := network.ServiceEndpointPropertiesFormat{
Service: &storageService,
Locations: &endpointLocaions,
}
storageServiceExists := false
if subnet.SubnetPropertiesFormat == nil {
subnet.SubnetPropertiesFormat = &network.SubnetPropertiesFormat{}
return nil, err
}
if subnet.SubnetPropertiesFormat.ServiceEndpoints == nil {
subnet.SubnetPropertiesFormat.ServiceEndpoints = &[]network.ServiceEndpointPropertiesFormat{}
if cache != nil {
vnetResourceIDs = cache.([]string)
klog.V(2).Infof("subnet %s under vnet %s in rg %s is already updated, vnetResourceIDs: %v", subnetName, vnetName, vnetResourceGroup, vnetResourceIDs)
return vnetResourceIDs, nil
}
serviceEndpoints := *subnet.SubnetPropertiesFormat.ServiceEndpoints
for _, v := range serviceEndpoints {
if strings.HasPrefix(pointer.StringDeref(v.Service, ""), storageService) {
storageServiceExists = true
klog.V(4).Infof("serviceEndpoint(%s) is already in subnet(%s)", storageService, subnetName)
break

d.subnetLockMap.LockEntry(lockKey)
defer d.subnetLockMap.UnlockEntry(lockKey)

var subnets []network.Subnet
if subnetName != "" {
// list multiple subnets separated by comma
subnetNames := strings.Split(subnetName, ",")
for _, sn := range subnetNames {
sn = strings.TrimSpace(sn)
subnet, rerr := d.cloud.SubnetsClient.Get(ctx, vnetResourceGroup, vnetName, sn, "")
if rerr != nil {
return vnetResourceIDs, fmt.Errorf("failed to get the subnet %s under rg %s vnet %s: %v", subnetName, vnetResourceGroup, vnetName, rerr.Error())
}
subnets = append(subnets, subnet)
}
} else {
var rerr *retry.Error
subnets, rerr = d.cloud.SubnetsClient.List(ctx, vnetResourceGroup, vnetName)
if rerr != nil {
return vnetResourceIDs, fmt.Errorf("failed to list the subnets under rg %s vnet %s: %v", vnetResourceGroup, vnetName, rerr.Error())
}
}

if !storageServiceExists {
serviceEndpoints = append(serviceEndpoints, storageServiceEndpoint)
subnet.SubnetPropertiesFormat.ServiceEndpoints = &serviceEndpoints
for _, subnet := range subnets {
if subnet.Name == nil {
return vnetResourceIDs, fmt.Errorf("subnet name is nil")
}
sn := *subnet.Name
vnetResourceID := d.getSubnetResourceID(vnetResourceGroup, vnetName, sn)
klog.V(2).Infof("set vnetResourceID %s", vnetResourceID)
vnetResourceIDs = append(vnetResourceIDs, vnetResourceID)

endpointLocaions := []string{location}
storageServiceEndpoint := network.ServiceEndpointPropertiesFormat{
Service: &storageService,
Locations: &endpointLocaions,
}
storageServiceExists := false
if subnet.SubnetPropertiesFormat == nil {
subnet.SubnetPropertiesFormat = &network.SubnetPropertiesFormat{}
}
if subnet.SubnetPropertiesFormat.ServiceEndpoints == nil {
subnet.SubnetPropertiesFormat.ServiceEndpoints = &[]network.ServiceEndpointPropertiesFormat{}
}
serviceEndpoints := *subnet.SubnetPropertiesFormat.ServiceEndpoints
for _, v := range serviceEndpoints {
if strings.HasPrefix(pointer.StringDeref(v.Service, ""), storageService) {
storageServiceExists = true
klog.V(4).Infof("serviceEndpoint(%s) is already in subnet(%s)", storageService, sn)
break
}
}

if !storageServiceExists {
serviceEndpoints = append(serviceEndpoints, storageServiceEndpoint)
subnet.SubnetPropertiesFormat.ServiceEndpoints = &serviceEndpoints

klog.V(2).Infof("begin to update the subnet %s under vnet %s rg %s", subnetName, vnetName, vnetResourceGroup)
if err := d.cloud.SubnetsClient.CreateOrUpdate(ctx, vnetResourceGroup, vnetName, subnetName, subnet); err != nil {
return fmt.Errorf("failed to update the subnet %s under vnet %s: %v", subnetName, vnetName, err)
klog.V(2).Infof("begin to update the subnet %s under vnet %s in rg %s", sn, vnetName, vnetResourceGroup)
if err := d.cloud.SubnetsClient.CreateOrUpdate(ctx, vnetResourceGroup, vnetName, sn, subnet); err != nil {
return vnetResourceIDs, fmt.Errorf("failed to update the subnet %s under vnet %s: %v", sn, vnetName, err)
}
}
klog.V(2).Infof("serviceEndpoint(%s) is appended in subnet(%s)", storageService, subnetName)
}

return nil
// cache the subnet update
d.subnetCache.Set(lockKey, vnetResourceIDs)
return vnetResourceIDs, nil
}

func (d *Driver) getStorageEndPointSuffix() string {
Expand Down
40 changes: 14 additions & 26 deletions pkg/blob/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"k8s.io/client-go/kubernetes"
"k8s.io/utils/pointer"

"sigs.k8s.io/blob-csi-driver/pkg/util"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient"
azureprovider "sigs.k8s.io/cloud-provider-azure/pkg/provider"

"sigs.k8s.io/cloud-provider-azure/pkg/retry"
)

// TestGetCloudProvider tests the func getCloudProvider().
Expand Down Expand Up @@ -328,25 +327,14 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) {
testFunc func(t *testing.T)
}{
{
name: "[fail] no subnet",
testFunc: func(t *testing.T) {
retErr := retry.NewError(false, fmt.Errorf("the subnet does not exist"))
mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.Subnet{}, retErr).Times(1)
expectedErr := fmt.Errorf("failed to get the subnet %s under vnet %s: %v", config.SubnetName, config.VnetName, retErr)
err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
},
{
name: "[success] subnetPropertiesFormat is nil",
name: "[fail] subnet name is nil",
testFunc: func(t *testing.T) {
mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.Subnet{}, nil).Times(1)
mockSubnetClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1)

err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
if !reflect.DeepEqual(err, nil) {
_, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname")
expectedErr := fmt.Errorf("subnet name is nil")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
Expand All @@ -356,12 +344,11 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) {
testFunc: func(t *testing.T) {
fakeSubnet := network.Subnet{
SubnetPropertiesFormat: &network.SubnetPropertiesFormat{},
Name: pointer.String("subnetName"),
}

mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).Times(1)
mockSubnetClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1)

err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
_, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname")
if !reflect.DeepEqual(err, nil) {
t.Errorf("Unexpected error: %v", err)
}
Expand All @@ -374,12 +361,12 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) {
SubnetPropertiesFormat: &network.SubnetPropertiesFormat{
ServiceEndpoints: &[]network.ServiceEndpointPropertiesFormat{},
},
Name: pointer.String("subnetName"),
}

mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).Times(1)
mockSubnetClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1)
mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).AnyTimes()

err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
_, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname")
if !reflect.DeepEqual(err, nil) {
t.Errorf("Unexpected error: %v", err)
}
Expand All @@ -396,11 +383,12 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) {
},
},
},
Name: pointer.String("subnetName"),
}

mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).Times(1)
mockSubnetClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeSubnet, nil).AnyTimes()

err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
_, err := d.updateSubnetServiceEndpoints(ctx, "", "", "subnetname")
if !reflect.DeepEqual(err, nil) {
t.Errorf("Unexpected error: %v", err)
}
Expand All @@ -411,7 +399,7 @@ func TestUpdateSubnetServiceEndpoints(t *testing.T) {
testFunc: func(t *testing.T) {
d.cloud.SubnetsClient = nil
expectedErr := fmt.Errorf("SubnetsClient is nil")
err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
_, err := d.updateSubnetServiceEndpoints(ctx, "", "", "")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/blob/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ type Driver struct {
volStatsCache azcache.Resource
// a timed cache storing account which should use sastoken for azcopy based volume cloning
azcopySasTokenCache azcache.Resource
// a timed cache storing subnet operations
subnetCache azcache.Resource
// sas expiry time for azcopy in volume clone
sasTokenExpirationMinutes int
// timeout in minutes for waiting for azcopy to finish
Expand Down Expand Up @@ -306,6 +308,10 @@ func NewDriver(options *DriverOptions, kubeClient kubernetes.Interface, cloud *p
if d.volStatsCache, err = azcache.NewTimedCache(time.Duration(options.VolStatsCacheExpireInMinutes)*time.Minute, getter, false); err != nil {
klog.Fatalf("%v", err)
}
if d.subnetCache, err = azcache.NewTimedCache(10*time.Minute, getter, false); err != nil {
klog.Fatalf("%v", err)
}

d.mounter = &mount.SafeFormatAndMount{
Interface: mount.New(""),
Exec: utilexec.New(),
Expand Down
1 change: 1 addition & 0 deletions pkg/blob/blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func TestNewDriver(t *testing.T) {
fakedriver.dataPlaneAPIVolCache = driver.dataPlaneAPIVolCache
fakedriver.azcopySasTokenCache = driver.azcopySasTokenCache
fakedriver.volStatsCache = driver.volStatsCache
fakedriver.subnetCache = driver.subnetCache
fakedriver.cloud = driver.cloud
assert.Equal(t, driver, fakedriver)
}
Expand Down
12 changes: 3 additions & 9 deletions pkg/blob/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,9 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
storeAccountKey = false
if !pointer.BoolDeref(createPrivateEndpoint, false) {
// set VirtualNetworkResourceIDs for storage account firewall setting
subnets := strings.Split(subnetName, ",")
for _, subnet := range subnets {
subnet = strings.TrimSpace(subnet)
vnetResourceID := d.getSubnetResourceID(vnetResourceGroup, vnetName, subnet)
klog.V(2).Infof("set vnetResourceID(%s) for NFS protocol", vnetResourceID)
vnetResourceIDs = append(vnetResourceIDs, vnetResourceID)
if err := d.updateSubnetServiceEndpoints(ctx, vnetResourceGroup, vnetName, subnet); err != nil {
return nil, status.Errorf(codes.Internal, "update service endpoints failed with error: %v", err)
}
var err error
if vnetResourceIDs, err = d.updateSubnetServiceEndpoints(ctx, vnetResourceGroup, vnetName, subnetName); err != nil {
return nil, status.Errorf(codes.Internal, "update service endpoints failed with error: %v", err)
}
}
}
Expand Down
Loading