Skip to content

Commit 4702efa

Browse files
authored
Merge pull request #2609 from laozc/test-path-windows-call
feat: Migrate PathValid and SMB API to use WMI and Win32 API
2 parents 4f479e0 + 393d7f4 commit 4702efa

File tree

7 files changed

+158
-69
lines changed

7 files changed

+158
-69
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ require (
2424
github.com/microsoft/wmi v0.31.1
2525
github.com/onsi/ginkgo/v2 v2.23.4
2626
github.com/onsi/gomega v1.37.0
27+
github.com/pkg/errors v0.9.1
2728
github.com/rubiojr/go-vhd v0.0.0-20200706105327-02e210299021
2829
github.com/stretchr/testify v1.10.0
2930
go.uber.org/goleak v1.3.0
@@ -121,7 +122,6 @@ require (
121122
github.com/opencontainers/runtime-spec v1.2.0 // indirect
122123
github.com/opencontainers/selinux v1.11.1 // indirect
123124
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
124-
github.com/pkg/errors v0.9.1 // indirect
125125
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
126126
github.com/prometheus/client_golang v1.22.0 // indirect
127127
github.com/prometheus/client_model v0.6.2 // indirect

pkg/azurefile/azurefile_options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func (o *DriverOptions) AddFlags() *flag.FlagSet {
8686
fs.BoolVar(&o.AppendNoShareSockOption, "append-nosharesock-option", true, "Whether appending nosharesock option to smb mount command")
8787
fs.BoolVar(&o.AppendNoResvPortOption, "append-noresvport-option", true, "Whether appending noresvport option to nfs mount command")
8888
fs.BoolVar(&o.AppendActimeoOption, "append-actimeo-option", true, "Whether appending actimeo=0 option to nfs mount command")
89-
fs.BoolVar(&o.UseWinCIMAPI, "use-win-cim-api", false, "Whether performing azure file operations using CIM API or Powershell command on Windows node")
89+
fs.BoolVar(&o.UseWinCIMAPI, "use-win-cim-api", true, "Whether performing azure file operations using CIM API or Powershell command on Windows node")
9090
fs.IntVar(&o.SkipMatchingTagCacheExpireInMinutes, "skip-matching-tag-cache-expire-in-minutes", 30, "The cache expire time in minutes for skipMatchingTagCache")
9191
fs.IntVar(&o.VolStatsCacheExpireInMinutes, "vol-stats-cache-expire-in-minutes", 10, "The cache expire time in minutes for volume stats cache")
9292
fs.BoolVar(&o.PrintVolumeStatsCallLogs, "print-volume-stats-call-logs", false, "Whether to print volume statfs call logs with log level 2")

pkg/mounter/safe_mounter_host_process_windows.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ func (mounter *winMounter) SMBMount(source, target, fsType string, mountOptions,
8787

8888
isMapped, err := mounter.smbAPI.IsSmbMapped(remotePath)
8989
if err != nil {
90+
klog.Errorf("IsSmbMapped(%s) failed with %v", remotePath, err)
9091
isMapped = false
9192
}
9293

pkg/os/cim/smb.go

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
package cim
2121

2222
import (
23+
"strings"
24+
2325
"github.com/microsoft/wmi/pkg/base/query"
2426
cim "github.com/microsoft/wmi/pkg/wmiinstance"
2527
)
@@ -33,8 +35,24 @@ const (
3335
SmbMappingStatusConnecting
3436
SmbMappingStatusReconnecting
3537
SmbMappingStatusUnavailable
38+
39+
credentialDelimiter = ":"
3640
)
3741

42+
// escapeQueryParameter escapes a parameter for WMI Queries
43+
func escapeQueryParameter(s string) string {
44+
s = strings.ReplaceAll(s, "'", "''")
45+
s = strings.ReplaceAll(s, "\\", "\\\\")
46+
return s
47+
}
48+
49+
func escapeUserName(userName string) string {
50+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
51+
userName = strings.ReplaceAll(userName, "\\", "\\\\")
52+
userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter)
53+
return userName
54+
}
55+
3856
// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path.
3957
//
4058
// The equivalent WMI query is:
@@ -44,7 +62,7 @@ const (
4462
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
4563
// for the WMI class definition.
4664
func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) {
47-
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
65+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
4866
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
4967
if err != nil {
5068
return nil, err
@@ -53,12 +71,22 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err
5371
return instances[0], err
5472
}
5573

56-
// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path.
74+
// GetSmbGlobalMappingStatus returns the status of an SMB global mapping.
75+
func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) {
76+
statusProp, err := inst.GetProperty("Status")
77+
if err != nil {
78+
return SmbMappingStatusUnavailable, err
79+
}
80+
81+
return statusProp.(int32), nil
82+
}
83+
84+
// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path.
5785
//
5886
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
5987
// for the WMI class definition.
6088
func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
61-
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath)
89+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
6290
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
6391
if err != nil {
6492
return err
@@ -67,3 +95,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
6795
_, err = instances[0].InvokeMethod("Remove", true)
6896
return err
6997
}
98+
99+
// NewSmbGlobalMapping creates a new SMB global mapping to the remote path.
100+
//
101+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
102+
// for the WMI class definition.
103+
func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) {
104+
params := map[string]interface{}{
105+
"RemotePath": remotePath,
106+
"RequirePrivacy": requirePrivacy,
107+
}
108+
if username != "" {
109+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
110+
// on how SMB credential is handled in PowerShell
111+
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
112+
}
113+
114+
result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
115+
return result, err
116+
}

pkg/os/cim/wmi.go

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,32 @@ limitations under the License.
2020
package cim
2121

2222
import (
23+
"errors"
2324
"fmt"
25+
"runtime"
2426

2527
"github.com/go-ole/go-ole"
2628
"github.com/go-ole/go-ole/oleutil"
2729
"github.com/microsoft/wmi/pkg/base/query"
28-
"github.com/microsoft/wmi/pkg/errors"
30+
wmierrors "github.com/microsoft/wmi/pkg/errors"
2931
cim "github.com/microsoft/wmi/pkg/wmiinstance"
32+
"golang.org/x/sys/windows"
3033
"k8s.io/klog/v2"
3134
)
3235

3336
const (
34-
WMINamespaceRoot = "Root\\CimV2"
37+
WMINamespaceCimV2 = "Root\\CimV2"
3538
WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage"
3639
WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb"
3740
)
3841

3942
type InstanceHandler func(instance *cim.WmiInstance) (bool, error)
4043

41-
// An InstanceIndexer provides index key to a WMI Instance in a map
42-
type InstanceIndexer func(instance *cim.WmiInstance) (string, error)
43-
4444
// NewWMISession creates a new local WMI session for the given namespace, defaulting
4545
// to root namespace if none specified.
4646
func NewWMISession(namespace string) (*cim.WmiSession, error) {
4747
if namespace == "" {
48-
namespace = WMINamespaceRoot
48+
namespace = WMINamespaceCimV2
4949
}
5050

5151
sessionManager := cim.NewWmiSessionManager()
@@ -80,7 +80,7 @@ func QueryFromWMI(namespace string, query *query.WmiQuery, handler InstanceHandl
8080
}
8181

8282
if len(instances) == 0 {
83-
return errors.NotFound
83+
return wmierrors.NotFound
8484
}
8585

8686
var cont bool
@@ -114,7 +114,7 @@ func executeClassMethodParam(classInst *cim.WmiInstance, method *cim.WmiMethod,
114114

115115
iDispatchInstance := classInst.GetIDispatch()
116116
if iDispatchInstance == nil {
117-
return nil, errors.Wrapf(errors.InvalidInput, "InvalidInstance")
117+
return nil, wmierrors.Wrapf(wmierrors.InvalidInput, "InvalidInstance")
118118
}
119119
rawResult, err := iDispatchInstance.GetProperty("Methods_")
120120
if err != nil {
@@ -254,11 +254,52 @@ func InvokeCimMethod(namespace, class, methodName string, inputParameters map[st
254254
return int(result.ReturnValue), outputParameters, nil
255255
}
256256

257+
// IsNotFound returns true if it's a "not found" error.
258+
func IsNotFound(err error) bool {
259+
return wmierrors.IsNotFound(err)
260+
}
261+
257262
// IgnoreNotFound returns nil if the error is nil or a "not found" error,
258263
// otherwise returns the original error.
259264
func IgnoreNotFound(err error) error {
260-
if err == nil || errors.IsNotFound(err) {
265+
if err == nil || IsNotFound(err) {
261266
return nil
262267
}
263268
return err
264269
}
270+
271+
// WithCOMThread runs the given function `fn` on a locked OS thread
272+
// with COM initialized using COINIT_MULTITHREADED.
273+
//
274+
// This is necessary for using COM/OLE APIs directly (e.g., via go-ole),
275+
// because COM requires that initialization and usage occur on the same thread.
276+
//
277+
// It performs the following steps:
278+
// - Locks the current goroutine to its OS thread
279+
// - Calls ole.CoInitializeEx with COINIT_MULTITHREADED
280+
// - Executes the user-provided function
281+
// - Uninitializes COM
282+
// - Unlocks the thread
283+
//
284+
// If COM initialization fails, or if the user's function returns an error,
285+
// that error is returned by WithCOMThread.
286+
func WithCOMThread(fn func() error) error {
287+
runtime.LockOSThread()
288+
defer runtime.UnlockOSThread()
289+
290+
if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
291+
var oleError *ole.OleError
292+
if errors.As(err, &oleError) && oleError != nil && oleError.Code() == uintptr(windows.S_FALSE) {
293+
klog.V(10).Infof("COM library has been already initialized for the calling thread, proceeding to the function with no error")
294+
err = nil
295+
}
296+
if err != nil {
297+
return err
298+
}
299+
} else {
300+
klog.V(10).Infof("COM library is initialized for the calling thread")
301+
}
302+
defer ole.CoUninitialize()
303+
304+
return fn()
305+
}

pkg/os/filesystem/filesystem.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import (
2626
"regexp"
2727
"strings"
2828

29+
"github.com/pkg/errors"
30+
"golang.org/x/sys/windows"
2931
"k8s.io/klog/v2"
3032
"sigs.k8s.io/azurefile-csi-driver/pkg/util"
3133
)
@@ -84,14 +86,27 @@ func PathExists(path string) (bool, error) {
8486
}
8587

8688
func PathValid(_ context.Context, path string) (bool, error) {
87-
cmd := `Test-Path $Env:remotepath`
88-
cmdEnv := fmt.Sprintf("remotepath=%s", path)
89-
output, err := util.RunPowershellCmd(cmd, cmdEnv)
89+
klog.V(6).Infof("PathValid called with path: %s", path)
90+
pathString, err := windows.UTF16PtrFromString(path)
9091
if err != nil {
91-
return false, fmt.Errorf("returned output: %s, error: %v", string(output), err)
92+
klog.V(6).Infof("failed to convert path %s to UTF16: %v", path, err)
93+
return false, fmt.Errorf("invalid path: %w", err)
9294
}
9395

94-
return strings.HasPrefix(strings.ToLower(string(output)), "true"), nil
96+
attrs, err := windows.GetFileAttributes(pathString)
97+
if err != nil {
98+
klog.V(6).Infof("failed to get file attributes for path %s: %v", path, err)
99+
if errors.Is(err, windows.ERROR_PATH_NOT_FOUND) || errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_INVALID_NAME) {
100+
klog.Warningf("path %s does not exist or is invalid, error: %v", path, err)
101+
return false, nil
102+
}
103+
104+
// GetFileAttribute returns user or password incorrect for a disconnected SMB connection after the password is changed
105+
return false, fmt.Errorf("failed to get path %s attribute: %w", path, err)
106+
}
107+
108+
klog.V(6).Infof("GetFileAttributes for path %s returned attributes: %d", path, attrs)
109+
return attrs != windows.INVALID_FILE_ATTRIBUTES, nil
95110
}
96111

97112
func ValidatePathWindows(path string) error {

pkg/os/smb/smb_cim.go

Lines changed: 35 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,10 @@ limitations under the License.
2020
package smb
2121

2222
import (
23-
"fmt"
24-
"strings"
25-
23+
"k8s.io/klog/v2"
2624
"sigs.k8s.io/azurefile-csi-driver/pkg/os/cim"
2725
)
2826

29-
const (
30-
credentialDelimiter = ":"
31-
)
32-
33-
func remotePathForQuery(remotePath string) string {
34-
return strings.ReplaceAll(remotePath, "\\", "\\\\")
35-
}
36-
37-
func escapeUserName(userName string) string {
38-
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
39-
escaped := strings.ReplaceAll(userName, "\\", "\\\\")
40-
escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter)
41-
return escaped
42-
}
43-
4427
var _ SMBAPI = &cimSMBAPI{}
4528

4629
type cimSMBAPI struct{}
@@ -50,42 +33,44 @@ func NewCimSMBAPI() *cimSMBAPI {
5033
}
5134

5235
func (*cimSMBAPI) IsSmbMapped(remotePath string) (bool, error) {
53-
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
54-
if err != nil {
55-
return false, cim.IgnoreNotFound(err)
56-
}
57-
58-
status, err := inst.GetProperty("Status")
59-
if err != nil {
60-
return false, err
61-
}
62-
63-
return status.(int32) == cim.SmbMappingStatusOK, nil
36+
var isMapped bool
37+
err := cim.WithCOMThread(func() error {
38+
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
39+
if err != nil {
40+
klog.V(6).Infof("error querying smb mapping for remote path %s. err: %v", remotePath, err)
41+
return err
42+
}
43+
44+
status, err := cim.GetSmbGlobalMappingStatus(inst)
45+
if err != nil {
46+
klog.V(6).Infof("error getting smb mapping status for remote path %s. err: %v", remotePath, err)
47+
return err
48+
}
49+
50+
isMapped = status == cim.SmbMappingStatusOK
51+
return nil
52+
})
53+
return isMapped, cim.IgnoreNotFound(err)
6454
}
6555

6656
func (*cimSMBAPI) NewSmbGlobalMapping(remotePath, username, password string) error {
67-
params := map[string]interface{}{
68-
"RemotePath": remotePath,
69-
"RequirePrivacy": true,
70-
}
71-
if username != "" {
72-
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
73-
// on how SMB credential is handled in PowerShell
74-
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
75-
}
76-
77-
result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
78-
if err != nil {
79-
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
80-
}
81-
82-
return nil
57+
return cim.WithCOMThread(func() error {
58+
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, true)
59+
if err != nil {
60+
klog.V(6).Infof("error creating smb mapping for remote path %s. result %d, err: %v", remotePath, result, err)
61+
return err
62+
}
63+
return nil
64+
})
8365
}
8466

8567
func (*cimSMBAPI) RemoveSmbGlobalMapping(remotePath string) error {
86-
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath))
87-
if err != nil {
88-
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
89-
}
90-
return nil
68+
return cim.WithCOMThread(func() error {
69+
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
70+
if err != nil {
71+
klog.V(6).Infof("error removing smb mapping for remote path %s. err: %v", remotePath, err)
72+
return err
73+
}
74+
return nil
75+
})
9176
}

0 commit comments

Comments
 (0)