Skip to content

Commit 678d5f6

Browse files
committed
Ensure COM threading apartment in SMB APIs
1 parent e3b1c4e commit 678d5f6

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

pkg/os/smb/api.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,26 @@ func New(requirePrivacy bool) *SmbAPI {
2828
}
2929

3030
func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
31+
var status int32
3132
err := cim.WithCOMThread(func() error {
3233
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
3334
if err != nil {
34-
return false, cim.IgnoreNotFound(err)
35+
return err
3536
}
3637

37-
status, err := inst.GetProperty("Status")
38+
statusProp, err := inst.GetProperty("Status")
3839
if err != nil {
39-
return false, err
40+
return err
4041
}
42+
43+
status = statusProp.(int32)
44+
return nil
4145
})
46+
if err != nil {
47+
return false, cim.IgnoreNotFound(err)
48+
}
4249

43-
return status.(int32) == cim.SmbMappingStatusOK, nil
50+
return status == cim.SmbMappingStatusOK, nil
4451
}
4552

4653
// NewSmbLink - creates a directory symbolic link to the remote share.
@@ -64,19 +71,23 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error {
6471
}
6572

6673
func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error {
67-
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
68-
if err != nil {
69-
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
70-
}
74+
return cim.WithCOMThread(func() error {
75+
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
76+
if err != nil {
77+
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
78+
}
7179

72-
return nil
80+
return nil
81+
})
7382
}
7483

7584
func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error {
76-
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
77-
if err != nil {
78-
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
79-
}
85+
return cim.WithCOMThread(func() error {
86+
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
87+
if err != nil {
88+
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
89+
}
8090

81-
return nil
91+
return nil
92+
})
8293
}

0 commit comments

Comments
 (0)