Skip to content

Commit 927a556

Browse files
committed
Merge branch 'quinnj-jq/scalarfuncs'
2 parents 72430bb + 0e420d3 commit 927a556

File tree

3 files changed

+63
-65
lines changed

3 files changed

+63
-65
lines changed

src/SQLite.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
SQLiteDB(file,handle) = SQLiteDB(file,handle,0)
3232

3333
include("UDF.jl")
34-
export registerfunc, sqlreturn, @scalarfunc, @sr_str
34+
export @sr_str, @register, register
3535

3636

3737
function changes(db::SQLiteDB)
@@ -54,7 +54,7 @@ function SQLiteDB(file::AbstractString="";UTF16::Bool=false)
5454
file = isempty(file) ? file : expanduser(file)
5555
if @OK sqliteopen(utf(file),handle)
5656
db = SQLiteDB(utf(file),handle[1])
57-
registerfunc(db, 2, regexp)
57+
register(db, regexp, 2)
5858
finalizer(db,close)
5959
return db
6060
else # error

src/UDF.jl

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,3 @@
1-
# scalar functions
2-
function registerfunc(db::SQLiteDB, nargs::Integer, func::Function, isdeterm::Bool=true; name="")
3-
@assert nargs <= 127 "only varargs functions can have more than 127 arguments"
4-
# assume any negative number means a varargs function
5-
nargs < -1 && (nargs = -1)
6-
7-
name = isempty(name) ? string(func) : name::AbstractString
8-
@assert sizeof(name) <= 255 "size of function name must be <= 255"
9-
10-
cfunc = cfunction(func, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
11-
12-
# TODO: allow the other encodings
13-
enc = SQLITE_UTF8
14-
enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc
15-
16-
@CHECK db sqlite3_create_function_v2(
17-
db.handle, name, nargs, enc, C_NULL, cfunc, C_NULL, C_NULL, C_NULL
18-
)
19-
end
20-
21-
# aggregate functions
22-
function registerfunc(db::SQLiteDB, nargs::Integer, step::Function, final::Function, isdeterm::Bool=true; name="")
23-
@assert nargs <= 127 "only varargs functions can have more than 127 arguments"
24-
# assume any negative number means a varargs function
25-
nargs < -1 && (nargs = -1)
26-
27-
name = isempty(name) ? string(step) : name::AbstractString
28-
cstep = cfunction(step, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
29-
cfinal = cfunction(final, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
30-
31-
# TODO: allow the other encodings
32-
enc = SQLITE_UTF8
33-
enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc
34-
35-
@CHECK db sqlite3_create_function_v2(
36-
db.handle, name, nargs, enc, C_NULL, C_NULL, cstep, cfinal, C_NULL
37-
)
38-
end
39-
401
function sqlvalue(values, i)
412
temp_val_ptr = unsafe_load(values, i)
423
valuetype = sqlite3_value_type(temp_val_ptr)
@@ -76,30 +37,55 @@ sqlreturn(context, val::Bool) = sqlreturn(context, int(val))
7637
sqludferror(context, msg::AbstractString) = sqlite3_result_error(context, msg)
7738
sqludferror(context, msg::UTF16String) = sqlite3_result_error16(context, msg)
7839

79-
function funcname(expr)
80-
if length(expr) == 2
81-
func = expr[2]
82-
name = expr[1]
83-
else
84-
func = expr[1]
85-
name = func.args[1].args[1]
86-
end
87-
name, func
88-
end
89-
90-
macro scalarfunc(args...)
91-
name, func = funcname(args)
40+
# Internal method for generating an SQLite scalar function from
41+
# a Julia function name
42+
function scalarfunc(func,fsym=symbol(string(func)))
43+
# check if name defined in Base so we don't clobber Base methods
44+
nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
9245
return quote
93-
function $(esc(name))(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
94-
args = [sqlvalue(values, i) for i in 1:nargs]
46+
#nm needs to be a symbol or expr, i.e. :sin or :(Base.sin)
47+
function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
48+
args = [SQLite.sqlvalue(values, i) for i in 1:nargs]
9549
ret = $(func)(args...)
96-
sqlreturn(context, ret)
50+
SQLite.sqlreturn(context, ret)
9751
nothing
9852
end
9953
end
10054
end
55+
function scalarfunc(expr::Expr)
56+
f = eval(expr)
57+
return scalarfunc(f)
58+
end
59+
# User-facing macro for convenience in registering a simple function
60+
# with no configurations needed
61+
macro register(db, func)
62+
:(register($(esc(db)), $(esc(func))))
63+
end
64+
# User-facing method with keyword arguments for registering a function
65+
# to be used within SQLite
66+
function register(db::SQLiteDB, func::Function; nargs::Int=-1, isdeterm::Bool=true, name::String=string(func))
67+
register(db, func, nargs, isdeterm, name)
68+
end
69+
# User-facing method for registering a Julia function to be used within SQLite
70+
function register(db::SQLiteDB, func::Function, nargs::Int=-1, isdeterm::Bool=true, name::String=string(func))
71+
@assert nargs <= 127 "use -1 if > 127 arguments are needed"
72+
# assume any negative number means a varargs function
73+
nargs < -1 && (nargs = -1)
74+
@assert sizeof(name) <= 255 "size of function name must be <= 255"
75+
76+
f = eval(scalarfunc(func,symbol(name)))
77+
78+
cfunc = cfunction(f, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
79+
# TODO: allow the other encodings
80+
enc = SQLITE_UTF8
81+
enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc
82+
83+
@CHECK db sqlite3_create_function_v2(
84+
db.handle, name, nargs, enc, C_NULL, cfunc, C_NULL, C_NULL, C_NULL
85+
)
86+
end
10187

10288
# annotate types because the MethodError makes more sense that way
103-
@scalarfunc regexp(r::AbstractString, s::AbstractString) = ismatch(Regex(r), s)
89+
regexp(r::String, s::String) = ismatch(Regex(r), s)
10490
# macro for preserving the special characters in a string
10591
macro sr_str(s) s end

test/runtests.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,25 +130,37 @@ end
130130
r = query(db, sr"SELECT LastName FROM Employee WHERE BirthDate REGEXP '^\d{4}-08'")
131131
@test r.values[1][1] == "Peacock"
132132

133-
@scalarfunc function triple(x)
134-
x * 3
135-
end
136-
@test_throws ErrorException registerfunc(db, 186, triple)
137-
registerfunc(db, 1, triple)
133+
triple(x) = x * 3
134+
@test_throws ErrorException SQLite.register(db, triple, 186)
135+
SQLite.register(db, triple, 1)
138136
r = query(db, "SELECT triple(Total) FROM Invoice ORDER BY InvoiceId LIMIT 5")
139137
s = query(db, "SELECT Total FROM Invoice ORDER BY InvoiceId LIMIT 5")
140138
for (i, j) in zip(r.values[1], s.values[1])
141139
@test_approx_eq i j*3
142140
end
143141

144-
@scalarfunc mult (*)
145-
registerfunc(db, -1, mult)
142+
SQLite.@register db function add4(q)
143+
q+4
144+
end
145+
r = query(db, "SELECT add4(AlbumId) FROM Album")
146+
s = query(db, "SELECT AlbumId FROM Album")
147+
@test r[1] == s[1]+4
148+
149+
SQLite.@register db mult(args...) = *(args...)
146150
r = query(db, "SELECT Milliseconds, Bytes FROM Track")
147151
s = query(db, "SELECT mult(Milliseconds, Bytes) FROM Track")
148152
@test r[1].*r[2] == s[1]
149153
t = query(db, "SELECT mult(Milliseconds, Bytes, 3, 4) FROM Track")
150154
@test r[1].*r[2]*3*4 == t[1]
151155

156+
SQLite.@register db sin
157+
u = query(db, "select sin(milliseconds) from track limit 5")
158+
@test all(-1 .< u[1] .< 1)
159+
160+
SQLite.register(db, hypot; nargs=2, name="hypotenuse")
161+
v = query(db, "select hypotenuse(Milliseconds,bytes) from track limit 5")
162+
@test [int(i) for i in v[1]] == [11175621,5521062,3997652,4339106,6301714]
163+
152164
@test size(tables(db)) == (11,1)
153165

154166
close(db)

0 commit comments

Comments
 (0)