Skip to content

Commit 57d9aeb

Browse files
committed
Merge pull request #268 from ianlancetaylor/handle
callback: use handles rather than passing Go pointers
2 parents 0cc1174 + 8c66b9c commit 57d9aeb

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

callback.go

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,75 @@ import (
2424
"fmt"
2525
"math"
2626
"reflect"
27+
"sync"
2728
"unsafe"
2829
)
2930

3031
//export callbackTrampoline
3132
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
3233
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
33-
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
34+
fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
3435
fi.Call(ctx, args)
3536
}
3637

3738
//export stepTrampoline
3839
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
3940
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
40-
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
41+
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
4142
ai.Step(ctx, args)
4243
}
4344

4445
//export doneTrampoline
4546
func doneTrampoline(ctx *C.sqlite3_context) {
46-
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
47+
handle := uintptr(C.sqlite3_user_data(ctx))
48+
ai := lookupHandle(handle).(*aggInfo)
4749
ai.Done(ctx)
4850
}
4951

52+
// Use handles to avoid passing Go pointers to C.
53+
54+
type handleVal struct {
55+
db *SQLiteConn
56+
val interface{}
57+
}
58+
59+
var handleLock sync.Mutex
60+
var handleVals = make(map[uintptr]handleVal)
61+
var handleIndex uintptr = 100
62+
63+
func newHandle(db *SQLiteConn, v interface{}) uintptr {
64+
handleLock.Lock()
65+
defer handleLock.Unlock()
66+
i := handleIndex
67+
handleIndex++
68+
handleVals[i] = handleVal{db, v}
69+
return i
70+
}
71+
72+
func lookupHandle(handle uintptr) interface{} {
73+
handleLock.Lock()
74+
defer handleLock.Unlock()
75+
r, ok := handleVals[handle]
76+
if !ok {
77+
if handle >= 100 && handle < handleIndex {
78+
panic("deleted handle")
79+
} else {
80+
panic("invalid handle")
81+
}
82+
}
83+
return r.val
84+
}
85+
86+
func deleteHandles(db *SQLiteConn) {
87+
handleLock.Lock()
88+
defer handleLock.Unlock()
89+
for handle, val := range handleVals {
90+
if val.db == db {
91+
delete(handleVals, handle)
92+
}
93+
}
94+
}
95+
5096
// This is only here so that tests can refer to it.
5197
type callbackArgRaw C.sqlite3_value
5298

sqlite3.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
367367
if pure {
368368
opts |= C.SQLITE_DETERMINISTIC
369369
}
370-
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(uintptr(unsafe.Pointer(&fi))), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
370+
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
371371
if rv != C.SQLITE_OK {
372372
return c.lastError()
373373
}
@@ -492,7 +492,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool
492492
if pure {
493493
opts |= C.SQLITE_DETERMINISTIC
494494
}
495-
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(uintptr(unsafe.Pointer(&ai))), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
495+
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
496496
if rv != C.SQLITE_OK {
497497
return c.lastError()
498498
}
@@ -705,6 +705,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
705705

706706
// Close the connection.
707707
func (c *SQLiteConn) Close() error {
708+
deleteHandles(c)
708709
rv := C.sqlite3_close_v2(c.db)
709710
if rv != C.SQLITE_OK {
710711
return c.lastError()

0 commit comments

Comments
 (0)