Skip to content

Commit 7c6e8da

Browse files
committed
use global state
1 parent 607bfa6 commit 7c6e8da

File tree

7 files changed

+56
-81
lines changed

7 files changed

+56
-81
lines changed

convert.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func From(v any) Object {
6363
switch vv.Kind() {
6464
case reflect.Ptr:
6565
if vv.Elem().Type().Kind() == reflect.Struct {
66-
maps := getCurrentThreadData()
66+
maps := getGlobalData()
6767
if pyType, ok := maps.pyTypes[vv.Elem().Type()]; ok {
6868
wrapper := allocWrapper((*C.PyTypeObject)(unsafe.Pointer(pyType)), vv.Interface())
6969
return newObject((*C.PyObject)(unsafe.Pointer(wrapper)))
@@ -173,7 +173,7 @@ func ToValue(obj Object, v reflect.Value) bool {
173173
}
174174
}
175175
} else {
176-
maps := getCurrentThreadData()
176+
maps := getGlobalData()
177177
tyMeta := maps.typeMetas[obj.Type().Obj()]
178178
if tyMeta == nil {
179179
return false
@@ -192,7 +192,7 @@ func fromSlice(v reflect.Value) List {
192192
l := v.Len()
193193
list := newList(C.PyList_New(C.Py_ssize_t(l)))
194194
ty := v.Type().Elem()
195-
maps := getCurrentThreadData()
195+
maps := getGlobalData()
196196
pyType, ok := maps.pyTypes[ty]
197197
if !ok {
198198
for i := 0; i < l; i++ {
@@ -221,7 +221,7 @@ func fromMap(v reflect.Value) Dict {
221221

222222
func fromStruct(v reflect.Value) Object {
223223
ty := v.Type()
224-
maps := getCurrentThreadData()
224+
maps := getGlobalData()
225225
if typeObj, ok := maps.pyTypes[ty]; ok {
226226
ptr := reflect.New(ty)
227227
ptr.Elem().Set(v)

extension.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,21 @@ func allocWrapper(typ *C.PyTypeObject, obj any) *wrapperType {
6666
wrapper := (*wrapperType)(unsafe.Pointer(self))
6767
holder := new(objectHolder)
6868
holder.obj = obj
69-
maps := getCurrentThreadData()
69+
maps := getGlobalData()
7070
maps.holders.PushFront(holder)
7171
wrapper.goObj = holder.obj
7272
wrapper.holder = holder
7373
return wrapper
7474
}
7575

7676
func freeWrapper(wrapper *wrapperType) {
77-
maps := getCurrentThreadData()
77+
maps := getGlobalData()
7878
maps.holders.Remove(wrapper.holder)
7979
}
8080

8181
//export wrapperAlloc
8282
func wrapperAlloc(typ *C.PyTypeObject, size C.Py_ssize_t) *C.PyObject {
83-
maps := getCurrentThreadData()
83+
maps := getGlobalData()
8484
meta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(typ))]
8585
wrapper := allocWrapper(typ, reflect.New(meta.typ).Interface())
8686
if wrapper == nil {
@@ -99,7 +99,7 @@ func wrapperDealloc(self *C.PyObject) {
9999
//export wrapperInit
100100
func wrapperInit(self, args *C.PyObject) C.int {
101101
typ := (*C.PyObject)(self).ob_type
102-
maps := getCurrentThreadData()
102+
maps := getGlobalData()
103103
typeMeta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(typ))]
104104
if typeMeta.init == nil {
105105
return 0
@@ -112,7 +112,7 @@ func wrapperInit(self, args *C.PyObject) C.int {
112112

113113
//export getterMethod
114114
func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.PyObject {
115-
maps := getCurrentThreadData()
115+
maps := getGlobalData()
116116
typeMeta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
117117
if typeMeta == nil {
118118
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
@@ -157,7 +157,7 @@ func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.
157157

158158
//export setterMethod
159159
func setterMethod(self, value *C.PyObject, _closure unsafe.Pointer, methodId C.int) C.int {
160-
maps := getCurrentThreadData()
160+
maps := getGlobalData()
161161
typeMeta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
162162
if typeMeta == nil {
163163
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
@@ -238,7 +238,7 @@ func wrapperMethod(self, args *C.PyObject, methodId C.int) *C.PyObject {
238238
key = (*C.PyObject)(unsafe.Pointer(self.ob_type))
239239
}
240240

241-
maps := getCurrentThreadData()
241+
maps := getGlobalData()
242242
typeMeta, ok := maps.typeMetas[key]
243243
if !ok {
244244
SetError(fmt.Errorf("type %v not registered", FromPy(key)))
@@ -457,7 +457,7 @@ func (m Module) AddType(obj, init any, name, doc string) Object {
457457
}
458458

459459
// Check if type already registered
460-
maps := getCurrentThreadData()
460+
maps := getGlobalData()
461461
if pyType, ok := maps.pyTypes[ty]; ok {
462462
return newObject(pyType)
463463
}
@@ -569,7 +569,7 @@ func (m Module) AddType(obj, init any, name, doc string) Object {
569569
}
570570
// Recursively register struct types
571571
if fieldType.Kind() == reflect.Struct {
572-
maps := getCurrentThreadData()
572+
maps := getGlobalData()
573573
if _, ok := maps.pyTypes[fieldType]; !ok {
574574
// Generate a unique type name based on package path and type name
575575
nestedName := fieldType.Name()
@@ -600,7 +600,7 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
600600
name = goNameToPythonName(name)
601601
doc = name + doc
602602

603-
maps := getCurrentThreadData()
603+
maps := getGlobalData()
604604
meta, ok := maps.typeMetas[m.obj]
605605
if !ok {
606606
meta = &typeMeta{

function.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ func (f Func) callOneArg(arg Objecter) Object {
3636
}
3737

3838
func (f Func) CallObject(args Tuple) Object {
39-
defer getCurrentThreadData().decRefObjectsIfNeeded()
39+
defer getGlobalData().decRefObjectsIfNeeded()
4040
return newObject(C.PyObject_CallObject(f.obj, args.obj))
4141
}
4242

4343
func (f Func) CallObjectKw(args Tuple, kw KwArgs) Object {
44-
defer getCurrentThreadData().decRefObjectsIfNeeded()
44+
defer getGlobalData().decRefObjectsIfNeeded()
4545
// Convert keyword arguments to Python dict
4646
kwDict := MakeDict(nil)
4747
for k, v := range kw {
@@ -53,7 +53,7 @@ func (f Func) CallObjectKw(args Tuple, kw KwArgs) Object {
5353
func (f Func) Call(args ...any) Object {
5454
argsTuple, kwArgs := splitArgs(args...)
5555
if kwArgs == nil {
56-
defer getCurrentThreadData().decRefObjectsIfNeeded()
56+
defer getGlobalData().decRefObjectsIfNeeded()
5757
switch len(args) {
5858
case 0:
5959
return f.callNoArgs()

thread_local.go renamed to global_data.go

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ package gp
66
import "C"
77

88
import (
9-
"fmt"
109
"reflect"
11-
"runtime"
1210
"sync"
1311
)
1412

13+
// ----------------------------------------------------------------------------
14+
1515
type holderList struct {
1616
head *objectHolder
1717
}
@@ -35,6 +35,10 @@ func (l *holderList) Remove(holder *objectHolder) {
3535
}
3636
}
3737

38+
// ----------------------------------------------------------------------------
39+
40+
const maxPyObjects = 128
41+
3842
type decRefList struct {
3943
objects []*C.PyObject
4044
mu sync.Mutex
@@ -59,69 +63,52 @@ func (l *decRefList) decRefAll() {
5963
}
6064
}
6165

62-
type threadData struct {
66+
// ----------------------------------------------------------------------------
67+
68+
type globalData struct {
6369
typeMetas map[*C.PyObject]*typeMeta
6470
pyTypes map[reflect.Type]*C.PyObject
6571
holders holderList
6672
decRefList decRefList
73+
finished int32
6774
}
6875

69-
const maxPyObjects = 128
76+
var (
77+
global *globalData
78+
once sync.Once
79+
)
7080

71-
func (td *threadData) addPyObject(obj *C.PyObject) {
72-
td.decRefList.add(obj)
81+
func getGlobalData() *globalData {
82+
return global
7383
}
7484

75-
func (td *threadData) decRefObjectsIfNeeded() {
76-
if len(td.decRefList.objects) > maxPyObjects {
77-
td.decRefList.decRefAll()
85+
func (gd *globalData) addDecRef(obj *C.PyObject) {
86+
if gd.finished != 0 {
87+
return
7888
}
89+
gd.decRefList.add(obj)
7990
}
8091

81-
var (
82-
globalThreadData sync.Map // map[int64]*threadData
83-
)
84-
85-
func getCurrentThreadData() *threadData {
86-
id := getThreadID()
87-
return getThreadData(id)
88-
}
89-
90-
func getThreadData(gid int64) *threadData {
91-
id := getThreadID()
92-
maps, ok := globalThreadData.Load(id)
93-
if !ok {
94-
// if not exists, create new thread data
95-
maps = &threadData{
96-
typeMetas: make(map[*C.PyObject]*typeMeta),
97-
pyTypes: make(map[reflect.Type]*C.PyObject),
98-
}
99-
globalThreadData.Store(id, maps)
92+
func (gd *globalData) decRefObjectsIfNeeded() {
93+
if gd.finished != 0 {
94+
return
10095
}
101-
return maps.(*threadData)
96+
gd.decRefList.decRefAll()
10297
}
10398

104-
func initThreadLocal() {
105-
id := getThreadID()
106-
maps := &threadData{
99+
// ----------------------------------------------------------------------------
100+
101+
func initGlobal() {
102+
global = &globalData{
107103
typeMetas: make(map[*C.PyObject]*typeMeta),
108104
pyTypes: make(map[reflect.Type]*C.PyObject),
109105
}
110-
globalThreadData.Store(id, maps)
111106
}
112107

113-
func cleanupThreadLocal() {
114-
id := getThreadID()
115-
globalThreadData.Delete(id)
108+
func markFinished() {
109+
global.finished = 1
116110
}
117111

118-
func getThreadID() int64 {
119-
var buf [64]byte
120-
n := runtime.Stack(buf[:], false)
121-
id := int64(0)
122-
_, err := fmt.Sscanf(string(buf[:n]), "goroutine %d ", &id)
123-
if err != nil {
124-
panic(err)
125-
}
126-
return id
112+
func cleanupGlobal() {
113+
global = nil
127114
}

module_test.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ import (
66

77
func TestModuleImport(t *testing.T) {
88
setupTest(t)
9-
Initialize()
10-
defer Finalize()
11-
129
// Test importing a built-in module
1310
mathMod := ImportModule("math")
1411
if mathMod.Nil() {
@@ -29,9 +26,6 @@ func TestModuleImport(t *testing.T) {
2926

3027
func TestGetModule(t *testing.T) {
3128
setupTest(t)
32-
Initialize()
33-
defer Finalize()
34-
3529
// First import the module
3630
sysModule := ImportModule("sys")
3731
if sysModule.Nil() {
@@ -52,8 +46,6 @@ func TestGetModule(t *testing.T) {
5246

5347
func TestCreateModule(t *testing.T) {
5448
setupTest(t)
55-
Initialize()
56-
defer Finalize()
5749

5850
// Create a new module
5951
modName := "test_module"
@@ -84,8 +76,6 @@ func TestCreateModule(t *testing.T) {
8476

8577
func TestGetModuleDict(t *testing.T) {
8678
setupTest(t)
87-
Initialize()
88-
defer Finalize()
8979

9080
// Get the module dictionary
9181
moduleDict := GetModuleDict()

object.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
// the Python Object's DecRef method during garbage collection.
1616
type pyObject struct {
1717
obj *C.PyObject
18-
gid int64
1918
}
2019

2120
func (obj *pyObject) Obj() *PyObject {
@@ -55,14 +54,12 @@ func newObject(obj *PyObject) Object {
5554
C.PyErr_Print()
5655
panic("nil Python object")
5756
}
58-
o := &pyObject{obj: obj, gid: getThreadID()}
59-
p := Object{o}
60-
57+
o := &pyObject{obj: obj}
6158
runtime.SetFinalizer(o, func(o *pyObject) {
62-
maps := getThreadData(o.gid)
63-
maps.addPyObject(o.obj)
59+
getGlobalData().addDecRef(o.obj)
60+
runtime.SetFinalizer(o, nil)
6461
})
65-
return p
62+
return Object{o}
6663
}
6764

6865
func (obj Object) Dir() List {

python.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ type PyCFunction = C.PyCFunction
1717
func Initialize() {
1818
runtime.LockOSThread()
1919
C.Py_Initialize()
20-
initThreadLocal()
20+
initGlobal()
2121
}
2222

2323
func Finalize() {
24-
cleanupThreadLocal()
24+
markFinished()
2525
r := C.Py_FinalizeEx()
26+
cleanupGlobal()
2627
check(r == 0, "failed to finalize Python")
2728
}
2829

0 commit comments

Comments
 (0)