Skip to content

Commit 6eb988a

Browse files
committed
Merge pull request #54 from quinnj/aggregates
2 parents 3e1d14d + ef1d1a9 commit 6eb988a

File tree

3 files changed

+217
-28
lines changed

3 files changed

+217
-28
lines changed

README.md

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ A Julia interface to the SQLite library and support for operations on DataFrames
7676

7777
`drop` is pretty self-explanatory. It's really just a convenience wrapper around `query` to execute a DROP TABLE command, while also calling "VACUUM" to clean out freed memory from the database.
7878

79-
* `registerfunc(db::SQLiteDB, nargs::Int, func::Function, isdeterm::Bool=true; name="")`
79+
* `register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)`
80+
* `register(db::SQLiteDB, init, step::Function, final::Function=identity; nargs::Int=-1, name::AbstractString=string(final), isdeterm::Bool=true)`
8081

81-
Register a function `func` (which takes `nargs` number of arguments) with the SQLite database connection `db`. If the keyword argument `name` is given the function is registered with that name, otherwise it is registered with the name of `func`. If the function is stochastic (e.g. uses a random number) `isdeterm` should be set to `false`, see SQLite's [function creation documentation](http://sqlite.org/c3ref/create_function.html) for more information.
82+
Register a scalar (first method) or aggregate (second method) function with a `SQLiteDB`.
8283

83-
* `@scalarfunc function`
84-
`@scalarfunc name function`
84+
* `@register db function`
8585

86-
Define a function which can then be passed to `registerfunc`. In the first usage the function name is infered from the function definition, in the second it is explicitly given as the first parameter. The second form is only recommended when it's use is absolutely necessary, see below.
86+
Automatically define then register `function` with a `SQLiteDB`.
8787

8888
* `sr"..."`
8989

@@ -188,45 +188,31 @@ The sr"..." currently escapes all special characters in a string but it may be c
188188

189189
##### Custom Scalar Functions
190190

191-
SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html) (though [Aggregate Functions](https://www.sqlite.org/lang_aggfunc.html) are not currently supported). This is done using the `registerfunc` function and `@scalarfunc` macro.
191+
SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html). This is done using the `register` function and macro.
192192

193-
`@scalarfunc` takes an optional function name and a function and defines a new function which can be passed to `registerfunc`. It can be used with block function syntax
193+
`@register` takes a `SQLiteDB` and a function. The function can be in block syntax
194194

195195
```julia
196-
julia> @scalarfunc function add3(x)
196+
julia> @register db function add3(x)
197197
x + 3
198198
end
199-
add3 (generic function with 1 method)
200-
201-
julia> @scalarfunc add5 function irrelevantfuncname(x)
202-
x + 5
203-
end
204-
add5 (generic function with 1 method)
205199
```
206200

207201
inline function syntax
208202

209203
```julia
210-
julia> @scalarfunc mult3(x) = 3 * x
211-
mult3 (generic function with 1 method)
212-
213-
julia> @scalarfunc mult5 anotherirrelevantname(x) = 5 * x
214-
mult5 (generic function with 1 method)
204+
julia> @register db mult3(x) = 3 * x
215205
```
216206

217-
and previously defined functions (note that name inference does not work with this method)
207+
and previously defined functions
218208

219209
```julia
220-
julia> @scalarfunc sin sin
221-
sin (generic function with 1 method)
222-
223-
julia> @scalarfunc subtract -
224-
subtract (generic function with 1 method)
210+
julia> @register db sin
225211
```
226212

227-
The function that is defined can then be passed to `registerfunc`. `registerfunc` takes three arguments; the database to which the function should be registered, the number of arguments that the function takes and the function itself. The function is registered to the database connection rather than the database itself so must be registered each time the database opens. Your function can not take more than 127 arguments unless it takes a variable number of arguments, if it does take a variable number of arguments then you must pass -1 as the second argument to `registerfunc`.
213+
The `register` function takes optional arguments; `nargs` which defaults to `-1`, `name` which defaults to the name of the function, `isdeterm` which defaults to `true`. In practice these rarely need to be used.
228214

229-
The `@scalarfunc` macro uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are
215+
The `register` function uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are
230216

231217
```julia
232218
sqlreturn(context, ::NullType)
@@ -251,3 +237,23 @@ sqlreturn(context, val::Bool) = sqlreturn(context, int(val))
251237
```
252238

253239
Any new method defined for `sqlreturn` must take two arguments and must pass the first argument straight through as the first argument.
240+
241+
#### Custom Aggregate Functions
242+
243+
Using the `register` function, you can also define your own aggregate functions with largely the same semantics.
244+
245+
The `register` function for aggregates must take a `SQLiteDB`, an initial value, a step function and a final function. The first argument to the step function will be the return value of the previous function (or the initial value if it is the first iteration). The final function must take a single argument which will be the return value of the last step function.
246+
247+
```julia
248+
julia> dsum(prev, cur) = prev + cur
249+
250+
julia> dsum(prev) = 2 * prev
251+
252+
julia> register(db, 0, dsum, dsum)
253+
```
254+
255+
If no name is given the name of the first (step) function is used (in this case "dsum"). You can also use lambdas, the following does the same as the previous code snippet
256+
257+
```julia
258+
julia> register(db, 0, (p,c) -> p+c, p -> 2p, name="dsum")
259+
```

src/UDF.jl

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,138 @@ function scalarfunc(expr::Expr)
5454
f = eval(expr)
5555
return scalarfunc(f)
5656
end
57+
58+
# convert a byteptr to an int, assumes little-endian
59+
function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
60+
s = 0
61+
for i in start:start+len-1
62+
v = unsafe_load(ptr, i)
63+
s += v * 256^(i - start)
64+
end
65+
66+
# swap byte-order on big-endian machines
67+
# TODO: this desperately needs testing on a big-endian machine!!!!!
68+
return htol(s)
69+
end
70+
71+
function stepfunc(init, func, fsym=symbol(string(func)*"_step"))
72+
nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
73+
return quote
74+
function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
75+
args = [sqlvalue(values, i) for i in 1:nargs]
76+
77+
intsize = sizeof(Int)
78+
ptrsize = sizeof(Ptr)
79+
acsize = intsize + ptrsize
80+
acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize))
81+
82+
# acptr will be zeroed-out if this is the first iteration
83+
ret = ccall(
84+
:memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Cuint),
85+
zeros(UInt8, acsize), acptr, acsize,
86+
)
87+
if ret == 0
88+
acval = $(init)
89+
valsize = 256
90+
# avoid the garbage collector using malloc
91+
valptr = convert(Ptr{UInt8}, c_malloc(valsize))
92+
valptr == C_NULL && throw(SQLiteException("memory error"))
93+
else
94+
# size of serialized value is first sizeof(Int) bytes
95+
valsize = bytestoint(acptr, 1, intsize)
96+
# ptr to serialized value is last sizeof(Ptr) bytes
97+
valptr = reinterpret(
98+
Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
99+
)
100+
# deserialize the value pointed to by valptr
101+
acvalbuf = zeros(UInt8, valsize)
102+
unsafe_copy!(pointer(acvalbuf), valptr, valsize)
103+
acval = sqldeserialize(acvalbuf)
104+
end
105+
106+
local funcret
107+
try
108+
funcret = sqlserialize($(func)(acval, args...))
109+
catch
110+
c_free(valptr)
111+
rethrow()
112+
end
113+
114+
newsize = sizeof(funcret)
115+
if newsize > valsize
116+
# TODO: increase this in a cleverer way?
117+
tmp = convert(Ptr{UInt8}, c_realloc(valptr, newsize))
118+
if tmp == C_NULL
119+
c_free(valptr)
120+
throw(SQLiteException("memory error"))
121+
else
122+
valptr = tmp
123+
end
124+
end
125+
# copy serialized return value
126+
unsafe_copy!(valptr, pointer(funcret), newsize)
127+
128+
# copy the size of the serialized value
129+
unsafe_copy!(
130+
acptr,
131+
pointer(reinterpret(UInt8, [newsize])),
132+
intsize
133+
)
134+
# copy the address of the pointer to the serialized value
135+
valarr = reinterpret(UInt8, [valptr])
136+
for i in 1:length(valarr)
137+
unsafe_store!(acptr, valarr[i], intsize+i)
138+
end
139+
nothing
140+
end
141+
end
142+
end
143+
144+
function finalfunc(init, func, fsym=symbol(string(func)*"_final"))
145+
nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
146+
return quote
147+
function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
148+
acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, 0))
149+
150+
# step function wasn't run
151+
if acptr == C_NULL
152+
sqlreturn(context, $(init))
153+
else
154+
intsize = sizeof(Int)
155+
ptrsize = sizeof(Ptr)
156+
acsize = intsize + ptrsize
157+
158+
# load size
159+
valsize = bytestoint(acptr, 1, intsize)
160+
# load ptr
161+
valptr = reinterpret(
162+
Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
163+
)
164+
165+
# load value
166+
acvalbuf = zeros(UInt8, valsize)
167+
unsafe_copy!(pointer(acvalbuf), valptr, valsize)
168+
acval = sqldeserialize(acvalbuf)
169+
170+
local ret
171+
try
172+
ret = $(func)(acval)
173+
finally
174+
c_free(valptr)
175+
end
176+
sqlreturn(context, ret)
177+
end
178+
nothing
179+
end
180+
end
181+
end
182+
57183
# User-facing macro for convenience in registering a simple function
58184
# with no configurations needed
59185
macro register(db, func)
60186
:(register($(esc(db)), $(esc(func))))
61187
end
188+
62189
# User-facing method for registering a Julia function to be used within SQLite
63190
function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)
64191
@assert nargs <= 127 "use -1 if > 127 arguments are needed"
@@ -78,6 +205,28 @@ function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractStr
78205
)
79206
end
80207

208+
# as above but for aggregate functions
209+
function register(
210+
db::SQLiteDB, init, step::Function, final::Function=identity;
211+
nargs::Int=-1, name::AbstractString=string(step), isdeterm::Bool=true
212+
)
213+
@assert nargs <= 127 "use -1 if > 127 arguments are needed"
214+
nargs < -1 && (nargs = -1)
215+
@assert sizeof(name) <= 255 "size of function name must be <= 255 chars"
216+
217+
s = eval(stepfunc(init, step, Base.function_name(step)))
218+
cs = cfunction(s, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
219+
f = eval(finalfunc(init, final, Base.function_name(final)))
220+
cf = cfunction(f, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
221+
222+
enc = SQLITE_UTF8
223+
enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc
224+
225+
@CHECK db sqlite3_create_function_v2(
226+
db.handle, name, nargs, enc, C_NULL, C_NULL, cs, cf, C_NULL
227+
)
228+
end
229+
81230
# annotate types because the MethodError makes more sense that way
82231
regexp(r::AbstractString, s::AbstractString) = ismatch(Regex(r), s)
83232
# macro for preserving the special characters in a string

test/runtests.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,41 @@ SQLite.@register db big
207207
r = query(db, "SELECT big(5)")
208208
@test r[1][1] == big(5)
209209

210+
doublesum_step(persist, current) = persist + current
211+
doublesum_final(persist) = 2 * persist
212+
register(db, 0, doublesum_step, doublesum_final, name="doublesum")
213+
r = query(db, "SELECT doublesum(UnitPrice) FROM Track")
214+
s = query(db, "SELECT UnitPrice FROM Track")
215+
@test_approx_eq r[1][1] 2*sum(s[1])
216+
217+
mycount(p, c) = p + 1
218+
register(db, 0, mycount)
219+
r = query(db, "SELECT mycount(TrackId) FROM PlaylistTrack")
220+
s = query(db, "SELECT count(TrackId) FROM PlaylistTrack")
221+
@test r[1] == s[1]
222+
223+
bigsum(p, c) = p + big(c)
224+
register(db, big(0), bigsum)
225+
r = query(db, "SELECT bigsum(TrackId) FROM PlaylistTrack")
226+
s = query(db, "SELECT TrackId FROM PlaylistTrack")
227+
@test r[1][1] == big(sum(s[1]))
228+
229+
query(db, "CREATE TABLE points (x INT, y INT, z INT)")
230+
query(db, "INSERT INTO points VALUES (?, ?, ?)", [1, 2, 3])
231+
query(db, "INSERT INTO points VALUES (?, ?, ?)", [4, 5, 6])
232+
query(db, "INSERT INTO points VALUES (?, ?, ?)", [7, 8, 9])
233+
type Point3D{T<:Number}
234+
x::T
235+
y::T
236+
z::T
237+
end
238+
==(a::Point3D, b::Point3D) = a.x == b.x && a.y == b.y && a.z == b.z
239+
+(a::Point3D, b::Point3D) = Point3D(a.x + b.x, a.y + b.y, a.z + b.z)
240+
sumpoint(p::Point3D, x, y, z) = p + Point3D(x, y, z)
241+
register(db, Point3D(0, 0, 0), sumpoint)
242+
r = query(db, "SELECT sumpoint(x, y, z) FROM points")
243+
@test r[1][1] == Point3D(12, 15, 18)
244+
drop(db, "points")
210245

211246
db2 = SQLiteDB()
212247
query(db2, "CREATE TABLE tab1 (r REAL, s INT)")
@@ -225,7 +260,6 @@ drop(db2, "tab2", ifexists=true)
225260

226261
close(db2)
227262

228-
229263
@test size(tables(db)) == (11,1)
230264

231265
close(db)

0 commit comments

Comments
 (0)