-
Notifications
You must be signed in to change notification settings - Fork 157
[WIP] feat: support SMB mount with managed identity #2652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,8 +76,7 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu | |
mountPermissions := d.mountPermissions | ||
context := req.GetVolumeContext() | ||
if context != nil { | ||
// token request | ||
if getValueInMap(context, serviceAccountTokenField) != "" && getValueInMap(context, clientIDField) != "" { | ||
if !strings.EqualFold(getValueInMap(context, mountWithManagedIdentityField), trueValue) && getValueInMap(context, serviceAccountTokenField) != "" && getValueInMap(context, clientIDField) != "" { | ||
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s", volumeID, target, getValueInMap(context, clientIDField)) | ||
_, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{ | ||
StagingTargetPath: target, | ||
|
@@ -248,7 +247,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe | |
volumeID := req.GetVolumeId() | ||
context := req.GetVolumeContext() | ||
|
||
if getValueInMap(context, clientIDField) != "" && getValueInMap(context, serviceAccountTokenField) == "" { | ||
if getValueInMap(context, clientIDField) != "" && !strings.EqualFold(getValueInMap(context, mountWithManagedIdentityField), trueValue) && getValueInMap(context, serviceAccountTokenField) == "" { | ||
klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID %s is provided but service account token is empty", volumeID, getValueInMap(context, clientIDField)) | ||
return &csi.NodeStageVolumeResponse{}, nil | ||
} | ||
|
@@ -277,9 +276,8 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe | |
} | ||
// don't respect fsType from req.GetVolumeCapability().GetMount().GetFsType() | ||
// since it's ext4 by default on Linux | ||
var fsType, server, protocol, ephemeralVolMountOptions, storageEndpointSuffix, folderName string | ||
var ephemeralVol bool | ||
var encryptInTransit bool | ||
var fsType, server, protocol, ephemeralVolMountOptions, storageEndpointSuffix, folderName, clientID string | ||
var ephemeralVol, encryptInTransit, mountWithManagedIdentity bool | ||
fileShareNameReplaceMap := map[string]string{} | ||
|
||
mountPermissions := d.mountPermissions | ||
|
@@ -313,7 +311,6 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe | |
fileShareNameReplaceMap[pvNameMetadata] = v | ||
case mountPermissionsField: | ||
if v != "" { | ||
var err error | ||
var perm uint64 | ||
if perm, err = strconv.ParseUint(v, 8, 32); err != nil { | ||
return nil, status.Errorf(codes.InvalidArgument, "invalid mountPermissions %s", v) | ||
|
@@ -325,11 +322,17 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe | |
} | ||
} | ||
case encryptInTransitField: | ||
var err error | ||
encryptInTransit, err = strconv.ParseBool(v) | ||
if err != nil { | ||
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("Volume context property %q must be a boolean value: %v", k, err)) | ||
} | ||
case mountWithManagedIdentityField: | ||
mountWithManagedIdentity, err = strconv.ParseBool(v) | ||
if err != nil { | ||
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("Volume context property %q must be a boolean value: %v", k, err)) | ||
} | ||
case clientIDField: | ||
clientID = v | ||
} | ||
} | ||
|
||
|
@@ -394,18 +397,29 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe | |
mountOptions = util.JoinMountOptions(mountFlags, []string{"vers=4,minorversion=1,sec=sys"}) | ||
mountOptions = appendDefaultNfsMountOptions(mountOptions, d.appendNoResvPortOption, d.appendActimeoOption) | ||
} else { | ||
if accountName == "" || accountKey == "" { | ||
return nil, status.Errorf(codes.Internal, "accountName(%s) or accountKey is empty", accountName) | ||
} | ||
if runtime.GOOS == "windows" { | ||
mountOptions = []string{fmt.Sprintf("AZURE\\%s", accountName)} | ||
sensitiveMountOptions = []string{accountKey} | ||
if mountWithManagedIdentity && runtime.GOOS != "windows" { | ||
if clientID == "" { | ||
clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID | ||
} | ||
sensitiveMountOptions = []string{"sec=krb5,cruid=0,upcall_target=mount", fmt.Sprintf("username=%s", clientID)} | ||
klog.V(2).Infof("using managed identity %s for volume %s with mount options: %v", clientID, volumeID, sensitiveMountOptions) | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] The nested if-else structure creates complex control flow. Consider restructuring to handle the different authentication methods (NFS, managed identity, traditional) in separate conditional blocks for better readability. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
if err := os.MkdirAll(targetPath, os.FileMode(mountPermissions)); err != nil { | ||
return nil, status.Error(codes.Internal, fmt.Sprintf("MkdirAll %s failed with error: %v", targetPath, err)) | ||
if accountName == "" || accountKey == "" { | ||
return nil, status.Errorf(codes.Internal, "accountName(%s) or accountKey is empty", accountName) | ||
} | ||
// parameters suggested by https://azure.microsoft.com/en-us/documentation/articles/storage-how-to-use-files-linux/ | ||
sensitiveMountOptions = []string{fmt.Sprintf("username=%s,password=%s", accountName, accountKey)} | ||
if runtime.GOOS == "windows" { | ||
mountOptions = []string{fmt.Sprintf("AZURE\\%s", accountName)} | ||
sensitiveMountOptions = []string{accountKey} | ||
} else { | ||
if err := os.MkdirAll(targetPath, os.FileMode(mountPermissions)); err != nil { | ||
return nil, status.Error(codes.Internal, fmt.Sprintf("MkdirAll %s failed with error: %v", targetPath, err)) | ||
} | ||
// parameters suggested by https://azure.microsoft.com/en-us/documentation/articles/storage-how-to-use-files-linux/ | ||
sensitiveMountOptions = []string{fmt.Sprintf("username=%s,password=%s", accountName, accountKey)} | ||
} | ||
} | ||
|
||
if runtime.GOOS != "windows" { | ||
if ephemeralVol { | ||
cifsMountFlags = util.JoinMountOptions(cifsMountFlags, strings.Split(ephemeralVolMountOptions, ",")) | ||
} | ||
|
@@ -449,6 +463,11 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe | |
klog.V(2).Infof("mount with proxy succeeded for %s", cifsMountPath) | ||
} else { | ||
execFunc := func() error { | ||
if mountWithManagedIdentity && protocol != nfs && runtime.GOOS != "windows" { | ||
if out, err := setCredentialCache(server, clientID); err != nil { | ||
return fmt.Errorf("setCredentialCache failed for %s with error: %v, output: %s", server, err, out) | ||
} | ||
} | ||
return SMBMount(d.mounter, source, cifsMountPath, mountFsType, mountOptions, sensitiveMountOptions) | ||
} | ||
timeoutFunc := func() error { return fmt.Errorf("time out") } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ import ( | |
"context" | ||
"fmt" | ||
"os" | ||
"os/exec" | ||
"regexp" | ||
"strconv" | ||
"strings" | ||
|
@@ -372,3 +373,10 @@ func removeOptionIfExists(options []string, removeOption string) ([]string, bool | |
} | ||
return options, false | ||
} | ||
|
||
func setCredentialCache(server, clientID string) ([]byte, error) { | ||
cmd := exec.Command("azfilesauthmanager", "set", "https://"+server, "--imds-client-id", clientID) | ||
cmd.Env = append(os.Environ(), cmd.Env...) | ||
klog.V(2).Infof("Executing command: %q", cmd.String()) | ||
return cmd.CombinedOutput() | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The clientID parameter should be validated before being passed to exec.Command to prevent command injection attacks, especially since it could come from user-provided volume context. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
err
variable is used without being declared. This will cause a compilation error sinceerr
was removed from the variable declaration on the previous lines.Copilot uses AI. Check for mistakes.