diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index c95ac59eba2..d57c1861554 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -200,6 +200,12 @@ func (d *Driver) NodeStageVolume(_ context.Context, req *csi.NodeStageVolumeRequ mountOptions = mountFlags if !gidPresent && volumeMountGroup != "" { mountOptions = append(mountOptions, fmt.Sprintf("gid=%s", volumeMountGroup)) + if !raiseGroupRWXInMountFlags(mountOptions, "file_mode") { + mountOptions = append(mountOptions, "file_mode=0774") + } + if !raiseGroupRWXInMountFlags(mountOptions, "dir_mode") { + mountOptions = append(mountOptions, "dir_mode=0775") + } } if domain != "" { mountOptions = append(mountOptions, fmt.Sprintf("%s=%s", domainField, domain)) @@ -608,3 +614,25 @@ func deleteKerberosCache(krb5CacheDirectory, volumeID string) error { return nil } + +// Raises RWX bits for group access in the mode arg. If mode is invalid, keep it unchanged. +func enableGroupRWX(mode string) string { + v, e := strconv.ParseInt(mode, 0, 0) + if e != nil || v < 0 { + return mode + } + return fmt.Sprintf("0%o", v|070) +} + +// Apply enableGroupRWX() to the option "flag=xyz" +func raiseGroupRWXInMountFlags(mountFlags []string, flag string) bool { + for i, mountFlag := range mountFlags { + mountFlagSplit := strings.Split(mountFlag, "=") + if len(mountFlagSplit) != 2 || mountFlagSplit[0] != flag { + continue + } + mountFlags[i] = fmt.Sprintf("%s=%s", flag, enableGroupRWX(mountFlagSplit[1])) + return true + } + return false +} diff --git a/pkg/smb/nodeserver_test.go b/pkg/smb/nodeserver_test.go index a6997c403ae..d8ca16ff9b6 100644 --- a/pkg/smb/nodeserver_test.go +++ b/pkg/smb/nodeserver_test.go @@ -57,6 +57,21 @@ func TestNodeStageVolume(t *testing.T) { Mount: &csi.VolumeCapability_MountVolume{}, }, } + mountGroupVolCap := csi.VolumeCapability{ + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{ + VolumeMountGroup: "1000", + }, + }, + } + mountGroupWithModesVolCap := csi.VolumeCapability{ + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{ + VolumeMountGroup: "1000", + MountFlags: []string{"file_mode=0111", "dir_mode=0111"}, + }, + }, + } errorMountSensSource := testutil.GetWorkDirPath("error_mount_sens_source", t) smbFile := testutil.GetWorkDirPath("smb.go", t) @@ -191,6 +206,30 @@ func TestNodeStageVolume(t *testing.T) { strings.Replace(testSource, "\\", "\\\\", -1), sourceTest, testSource, sourceTest), expectedErr: testutil.TestError{}, }, + { + desc: "[Success] Valid request with VolumeMountGroup", + req: csi.NodeStageVolumeRequest{VolumeId: "vol_1##", StagingTargetPath: sourceTest, + VolumeCapability: &mountGroupVolCap, + VolumeContext: volContext, + Secrets: secrets}, + skipOnWindows: true, + flakyWindowsErrorMessage: fmt.Sprintf("rpc error: code = Internal desc = volume(vol_1##) mount \"%s\" on %#v failed with "+ + "NewSmbGlobalMapping(%s, %s) failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.", + strings.Replace(testSource, "\\", "\\\\", -1), sourceTest, testSource, sourceTest), + expectedErr: testutil.TestError{}, + }, + { + desc: "[Success] Valid request with VolumeMountGroup and file/dir modes", + req: csi.NodeStageVolumeRequest{VolumeId: "vol_1##", StagingTargetPath: sourceTest, + VolumeCapability: &mountGroupWithModesVolCap, + VolumeContext: volContext, + Secrets: secrets}, + skipOnWindows: true, + flakyWindowsErrorMessage: fmt.Sprintf("rpc error: code = Internal desc = volume(vol_1##) mount \"%s\" on %#v failed with "+ + "NewSmbGlobalMapping(%s, %s) failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.", + strings.Replace(testSource, "\\", "\\\\", -1), sourceTest, testSource, sourceTest), + expectedErr: testutil.TestError{}, + }, } // Setup @@ -930,3 +969,92 @@ func TestNodePublishVolumeIdempotentMount(t *testing.T) { err = os.RemoveAll(targetTest) assert.NoError(t, err) } + +func TestEnableGroupRWX(t *testing.T) { + tests := []struct { + value string + expectedValue string + }{ + { + value: "qwerty", + expectedValue: "qwerty", + }, + { + value: "0111", + expectedValue: "0171", + }, + } + + for _, test := range tests { + mode := enableGroupRWX(test.value) + assert.Equal(t, test.expectedValue, mode) + } +} + +func TestRaiseGroupRWXInMountFlags(t *testing.T) { + tests := []struct { + mountFlags []string + flag string + expectedResult bool + mountFlagsUpdated bool + expectedMountFlags []string + }{ + { + mountFlags: []string{""}, + flag: "flag", + expectedResult: false, + }, + { + mountFlags: []string{"irrelevant"}, + flag: "flag", + expectedResult: false, + }, + { + mountFlags: []string{"key=val"}, + flag: "flag", + expectedResult: false, + }, + { + mountFlags: []string{"flag=key=val"}, + flag: "flag", + expectedResult: false, + }, + { + // This is important: if we return false here, the caller will append another flag=... + mountFlags: []string{"flag=invalid"}, + flag: "flag", + expectedResult: true, + }, + { + // Main case: raising group bits in the value + mountFlags: []string{"flag=0111"}, + flag: "flag", + expectedResult: true, + mountFlagsUpdated: true, + expectedMountFlags: []string{"flag=0171"}, + }, + } + + for _, test := range tests { + savedMountFlags := make([]string, len(test.mountFlags)) + copy(savedMountFlags, test.mountFlags) + + result := raiseGroupRWXInMountFlags(test.mountFlags, test.flag) + if result != test.expectedResult { + t.Errorf("raiseGroupRWXInMountFlags(%v, %s) returned %t (expected: %t)", + test.mountFlags, test.flag, result, test.expectedResult) + } + + if test.mountFlagsUpdated { + if !reflect.DeepEqual(test.expectedMountFlags, test.mountFlags) { + t.Errorf("raiseGroupRWXInMountFlags(%v, %s) did not update mountFlags (expected: %v)", + savedMountFlags, test.flag, test.expectedMountFlags) + } + } else { + if !reflect.DeepEqual(savedMountFlags, test.mountFlags) { + t.Errorf("raiseGroupRWXInMountFlags(%v, %s) updated mountFlags: %v", + savedMountFlags, test.flag, test.mountFlags) + } + } + } +}