Skip to content

Commit 99ceeb6

Browse files
committed
Fix TODOs.
Attempt to account for big-endianness. Don't call sqlserialize unnecessarily. Try to avoid memory-leaks.
1 parent 4a0bcff commit 99ceeb6

File tree

1 file changed

+44
-31
lines changed

1 file changed

+44
-31
lines changed

src/UDF.jl

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,56 +55,64 @@ function scalarfunc(expr::Expr)
5555
return scalarfunc(f)
5656
end
5757

58-
# convert a byteptr to an int, ptr[start] -> 256^0, ptr[start+1] -> 256^1...
59-
# TODO: this assumes little-endian
58+
# convert a byteptr to an int, assumes little-endian
6059
function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
6160
s = 0
6261
for i in start:start+len-1
6362
v = unsafe_load(ptr, i)
6463
s += v * 256^(i - start)
6564
end
6665

67-
return s
66+
# swap byte-order on big-endian machines
67+
# TODO: this desperately needs testing on a big-endian machine!!!!!
68+
return htol(s)
6869
end
6970

7071
function stepfunc(init, func, fsym=symbol(string(func)*"_step"))
7172
nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
7273
return quote
7374
function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
7475
args = [sqlvalue(values, i) for i in 1:nargs]
76+
7577
intsize = sizeof(Int)
7678
ptrsize = sizeof(Ptr)
7779
acsize = intsize + ptrsize
7880
acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize))
81+
7982
# acptr will be zeroed-out if this is the first iteration
8083
ret = ccall(
8184
:memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Cuint),
8285
zeros(UInt8, acsize), acptr, acsize,
8386
)
87+
if ret == 0
88+
acval = $(init)
89+
valsize = 256
90+
# avoid the garbage collector using malloc
91+
valptr = convert(Ptr{UInt8}, c_malloc(valsize))
92+
else
93+
# size of serialized value is first sizeof(Int) bytes
94+
valsize = bytestoint(acptr, 1, intsize)
95+
# ptr to serialized value is last sizeof(Ptr) bytes
96+
valptr = reinterpret(
97+
Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
98+
)
99+
# deserialize the value pointed to by valptr
100+
acvalbuf = zeros(UInt8, valsize)
101+
unsafe_copy!(pointer(acvalbuf), valptr, valsize)
102+
acval = sqldeserialize(acvalbuf)
103+
end
104+
84105
try
85-
if ret == 0
86-
acval = $(init)
87-
# TODO: allocate 256 byte
88-
valsize = sizeof(sqlserialize(acval))
89-
valptr = convert(Ptr{UInt8}, c_malloc(valsize))
90-
else
91-
# size of serialized value is first sizeof(Int) bytes
92-
valsize = bytestoint(acptr, 1, intsize)
93-
# ptr to serialized value is last sizeof(Ptr) bytes
94-
valptr = reinterpret(
95-
Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
96-
)
97-
# deserialize the value pointed to by valptr
98-
acvalbuf = zeros(UInt8, valsize)
99-
unsafe_copy!(pointer(acvalbuf), valptr, valsize)
100-
acval = sqldeserialize(acvalbuf)
101-
end
102106
funcret = sqlserialize($(func)(acval, args...))
103-
newsize = length(funcret)
104-
# TODO: increase this in a cleverer way?
105-
newsize > valsize && (valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize)))
107+
108+
newsize = sizeof(funcret)
109+
if newsize > valsize
110+
# TODO: increase this in a cleverer way?
111+
valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize))
112+
end
106113
# copy serialized return value
107114
unsafe_copy!(valptr, pointer(funcret), newsize)
115+
108116
# following copies are easier with arrays
109117
acptr = pointer_to_array(acptr, acsize, false)
110118
# copy the size of the serialized value
@@ -133,33 +141,38 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step"))
133141
end
134142
end
135143

136-
# TODO: free valptr on error
137144
function finalfunc(init, func, fsym=symbol(string(func)*"_final"))
138145
nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
139146
return quote
140147
function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
141148
acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, 0))
149+
142150
# step function wasn't run
143151
if acptr == C_NULL
144152
sqlreturn(context, $(init))
145153
else
146154
intsize = sizeof(Int)
147155
ptrsize = sizeof(Ptr)
148156
acsize = intsize + ptrsize
157+
149158
# load size
150159
valsize = bytestoint(acptr, 1, intsize)
151160
# load ptr
152161
valptr = reinterpret(
153162
Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
154163
)
155-
# load value
156-
acvalbuf = zeros(UInt8, valsize)
157-
unsafe_copy!(pointer(acvalbuf), valptr, valsize)
158-
acval = sqldeserialize(acvalbuf)
159164

160-
ret = $(func)(acval)
161-
c_free(valptr)
162-
sqlreturn(context, ret)
165+
try
166+
# load value
167+
acvalbuf = zeros(UInt8, valsize)
168+
unsafe_copy!(pointer(acvalbuf), valptr, valsize)
169+
acval = sqldeserialize(acvalbuf)
170+
171+
ret = $(func)(acval)
172+
sqlreturn(context, ret)
173+
finally
174+
c_free(valptr)
175+
end
163176
end
164177
nothing
165178
end

0 commit comments

Comments
 (0)