Skip to content

Commit e3b1c4e

Browse files
committed
Ensure COM threading apartment for API calls
1 parent 82cf3f9 commit e3b1c4e

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

pkg/cim/wmi.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package cim
55

66
import (
77
"fmt"
8+
"runtime"
89

910
"github.com/go-ole/go-ole"
1011
"github.com/go-ole/go-ole/oleutil"
@@ -248,3 +249,30 @@ func IgnoreNotFound(err error) error {
248249
}
249250
return err
250251
}
252+
253+
// WithCOMThread runs the given function `fn` on a locked OS thread
254+
// with COM initialized using COINIT_MULTITHREADED.
255+
//
256+
// This is necessary for using COM/OLE APIs directly (e.g., via go-ole),
257+
// because COM requires that initialization and usage occur on the same thread.
258+
//
259+
// It performs the following steps:
260+
// - Locks the current goroutine to its OS thread
261+
// - Calls ole.CoInitializeEx with COINIT_MULTITHREADED
262+
// - Executes the user-provided function
263+
// - Uninitializes COM
264+
// - Unlocks the thread
265+
//
266+
// If COM initialization fails, or if the user's function returns an error,
267+
// that error is returned by WithCOMThread.
268+
func WithCOMThread(fn func() error) error {
269+
runtime.LockOSThread()
270+
defer runtime.UnlockOSThread()
271+
272+
if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
273+
return err
274+
}
275+
defer ole.CoUninitialize()
276+
277+
return fn()
278+
}

pkg/os/smb/api.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,17 @@ func New(requirePrivacy bool) *SmbAPI {
2828
}
2929

3030
func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) {
31-
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
32-
if err != nil {
33-
return false, cim.IgnoreNotFound(err)
34-
}
35-
36-
status, err := inst.GetProperty("Status")
37-
if err != nil {
38-
return false, err
39-
}
31+
err := cim.WithCOMThread(func() error {
32+
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
33+
if err != nil {
34+
return false, cim.IgnoreNotFound(err)
35+
}
36+
37+
status, err := inst.GetProperty("Status")
38+
if err != nil {
39+
return false, err
40+
}
41+
})
4042

4143
return status.(int32) == cim.SmbMappingStatusOK, nil
4244
}

0 commit comments

Comments
 (0)