Skip to content

Commit 7a00040

Browse files
authored
Avoid using cfunction closure (#336)
1 parent dce0b6c commit 7a00040

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

src/SQLite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ mutable struct DB <: DBInterface.Connection
5858
file::String
5959
handle::DBHandle
6060
stmt_wrappers::WeakKeyDict{StmtWrapper,Nothing} # opened prepared statements
61-
registered_UDFs::Vector{Any} # keep registered UDFs alive and not garbage collected
61+
registered_UDF_data::Vector{Any} # keep registered UDFs alive and not garbage collected
6262

6363
function DB(f::AbstractString)
6464
handle_ptr = Ref{DBHandle}()

src/UDF.jl

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,25 @@ end
4444
sqlreturn(context, val::Bool) = sqlreturn(context, Int(val))
4545
sqlreturn(context, val) = sqlreturn(context, sqlserialize(val))
4646

47+
mutable struct ScalarUDFData
48+
func::Function
49+
end
50+
51+
mutable struct AggregateUDFData
52+
init::Any
53+
step::Function
54+
final::Function
55+
end
56+
4757
function wrap_scalarfunc(
48-
func,
4958
context::Ptr{Cvoid},
5059
nargs::Cint,
5160
values::Ptr{Ptr{Cvoid}},
5261
)
62+
udf_data =
63+
unsafe_pointer_to_objref(C.sqlite3_user_data(context))::ScalarUDFData
64+
func = udf_data.func
65+
5366
args = [sqlvalue(values, i) for i in 1:nargs]
5467
ret = func(args...)
5568
sqlreturn(context, ret)
@@ -70,12 +83,15 @@ function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
7083
end
7184

7285
function wrap_stepfunc(
73-
init,
74-
func,
7586
context::Ptr{Cvoid},
7687
nargs::Cint,
7788
values::Ptr{Ptr{Cvoid}},
7889
)
90+
udf_data =
91+
unsafe_pointer_to_objref(C.sqlite3_user_data(context))::AggregateUDFData
92+
init = udf_data.init
93+
func = udf_data.step
94+
7995
args = [sqlvalue(values, i) for i in 1:nargs]
8096

8197
intsize = sizeof(Int)
@@ -143,12 +159,15 @@ function wrap_stepfunc(
143159
end
144160

145161
function wrap_finalfunc(
146-
init,
147-
func,
148162
context::Ptr{Cvoid},
149163
nargs::Cint,
150164
values::Ptr{Ptr{Cvoid}},
151165
)
166+
udf_data =
167+
unsafe_pointer_to_objref(C.sqlite3_user_data(context))::AggregateUDFData
168+
init = udf_data.init
169+
func = udf_data.final
170+
152171
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, 0))
153172

154173
# step function wasn't run
@@ -201,7 +220,7 @@ Register a scalar (first method) or aggregate (second method) function
201220
with a [`SQLite.DB`](@ref).
202221
"""
203222
function register(
204-
db,
223+
db::DB,
205224
func::Function;
206225
nargs::Int = -1,
207226
name::AbstractString = string(func),
@@ -212,11 +231,12 @@ function register(
212231
nargs < -1 && (nargs = -1)
213232
@assert sizeof(name) <= 255 "size of function name must be <= 255"
214233

215-
f =
216-
(context, nargs, values) ->
217-
wrap_scalarfunc(func, context, nargs, values)
218-
cfunc = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
219-
push!(db.registered_UDFs, cfunc)
234+
udf_data = ScalarUDFData(func)
235+
push!(db.registered_UDF_data, udf_data)
236+
udf_data_ptr = pointer_from_objref(udf_data)
237+
238+
cfunc =
239+
@cfunction(wrap_scalarfunc, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
220240

221241
# TODO: allow the other encodings
222242
enc = C.SQLITE_UTF8
@@ -227,7 +247,7 @@ function register(
227247
name,
228248
nargs,
229249
enc,
230-
C_NULL,
250+
udf_data_ptr,
231251
cfunc,
232252
C_NULL,
233253
C_NULL,
@@ -237,7 +257,7 @@ end
237257

238258
# as above but for aggregate functions
239259
function register(
240-
db,
260+
db::DB,
241261
init,
242262
step::Function,
243263
final::Function = identity;
@@ -249,16 +269,12 @@ function register(
249269
nargs < -1 && (nargs = -1)
250270
@assert sizeof(name) <= 255 "size of function name must be <= 255 chars"
251271

252-
s =
253-
(context, nargs, values) ->
254-
wrap_stepfunc(init, step, context, nargs, values)
255-
cs = @cfunction($s, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
256-
f =
257-
(context, nargs, values) ->
258-
wrap_finalfunc(init, final, context, nargs, values)
259-
cf = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
260-
push!(db.registered_UDFs, cs)
261-
push!(db.registered_UDFs, cf)
272+
udf_data = AggregateUDFData(init, step, final)
273+
push!(db.registered_UDF_data, udf_data)
274+
udf_data_ptr = pointer_from_objref(udf_data)
275+
276+
cs = @cfunction(wrap_stepfunc, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
277+
cf = @cfunction(wrap_finalfunc, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
262278

263279
enc = C.SQLITE_UTF8
264280
enc = isdeterm ? enc | C.SQLITE_DETERMINISTIC : enc
@@ -268,7 +284,7 @@ function register(
268284
name,
269285
nargs,
270286
enc,
271-
C_NULL,
287+
udf_data_ptr,
272288
C_NULL,
273289
cs,
274290
cf,

0 commit comments

Comments
 (0)