Skip to content

Commit c1b99a1

Browse files
authored
WIP: also build compiled calls for ccall (#216)
* in some cases, also build compiled calls for ccall * fixes * put back precompile statement * fix working with sparams * fix compiling ccalls when they are RHS of assignments * fix for TypeofBottom * fix calls to get pointer in ccall
1 parent 8758ba9 commit c1b99a1

File tree

6 files changed

+137
-30
lines changed

6 files changed

+137
-30
lines changed

src/commands.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,9 @@ function maybe_step_through_wrapper!(@nospecialize(recurse), frame::Frame)
249249
end
250250
end
251251
ret = evaluate_call!(dummy_breakpoint, frame, last)
252-
@assert isa(ret, BreakpointRef)
252+
if !isa(ret, BreakpointRef) # Happens if next call is Compiled
253+
return frame
254+
end
253255
frame.framedata.ssavalues[frame.pc] = Wrapper()
254256
return maybe_step_through_wrapper!(recurse, callee(frame))
255257
end

src/interpret.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,18 @@ function evaluate_foreigncall(frame::Frame, call_expr::Expr)
171171
return Core.eval(moduleof(frame), Expr(head, args...))
172172
end
173173

174-
# We have to intercept llvmcall before we try it as a builtin
174+
# We have to intercept ccalls / llvmcalls before we try it as a builtin
175175
function bypass_builtins(frame, call_expr, pc)
176176
if isassigned(frame.framecode.methodtables, pc)
177177
tme = frame.framecode.methodtables[pc]
178178
if isa(tme, Compiled)
179179
fargs = collect_args(frame, call_expr)
180180
f = to_function(fargs[1])
181-
return Some{Any}(f(fargs[2:end]...))
181+
if parentmodule(f) === JuliaInterpreter.CompiledCalls
182+
return Some{Any}(Base.invokelatest(f, fargs[2:end]...))
183+
else
184+
return Some{Any}(f(fargs[2:end]...))
185+
end
182186
end
183187
end
184188
return nothing

src/optimize.jl

Lines changed: 116 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,14 @@ Currently it looks up `GlobalRef`s (for which it needs `mod` to know the scope i
104104
which this will run) and ensures that no statement includes nested `:call` expressions
105105
(splitting them out into multiple SSA-form statements if needed).
106106
"""
107-
function optimize!(code::CodeInfo, mod::Module)
107+
function optimize!(code::CodeInfo, scope)
108+
mod = moduleof(scope)
109+
sparams = scope isa Method ? Symbol[sparam_syms(scope)...] : Symbol[]
108110
code.inferred && error("optimization of inferred code not implemented")
109111
# TODO: because of builtins.jl, for CodeInfos like
110112
# %1 = Core.apply_type
111113
# %2 = (%1)(args...)
112114
# it would be best to *not* resolve the GlobalRef at %1
113-
114115
## Replace GlobalRefs with QuoteNodes
115116
for (i, stmt) in enumerate(code.code)
116117
if isa(stmt, GlobalRef)
@@ -150,42 +151,131 @@ function optimize!(code::CodeInfo, mod::Module)
150151
# Replace :llvmcall and :foreigncall with compiled variants. See
151152
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
152153
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(code.code))
154+
# @show code
153155
for (idx, stmt) in enumerate(code.code)
156+
# Foregincalls can be rhs of assignments
157+
if isexpr(stmt, :(=))
158+
stmt = stmt.args[2]
159+
end
154160
if isexpr(stmt, :call)
155161
# Check for :llvmcall
156162
arg1 = stmt.args[1]
157-
if arg1 == :llvmcall || lookup_stmt(code.code, arg1) == Base.llvmcall
163+
if (arg1 == :llvmcall || lookup_stmt(code.code, arg1) == Base.llvmcall) && isempty(sparams) && scope isa Method
158164
uuid = uuid4()
159165
ustr = replace(string(uuid), '-'=>'_')
160166
methname = Symbol("llvmcall_", ustr)
161167
nargs = length(stmt.args)-4
162-
argnames = [Symbol("arg", string(i)) for i = 1:nargs]
163-
# Run a mini-interpreter to extract the types
164-
framecode = FrameCode(CompiledCalls, code; optimize=false)
165-
frame = Frame(framecode, prepare_framedata(framecode, []))
166-
idxstart = idx
167-
for i = 2:4
168-
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
169-
end
170-
frame.pc = idxstart
171-
while true
172-
pc = step_expr!(Compiled(), frame)
173-
pc == idx && break
174-
pc === nothing && error("this should never happen")
175-
end
176-
str, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
177-
def = quote
178-
function $methname($(argnames...))
179-
return Base.llvmcall($str, $RetType, $ArgType, $(argnames...))
180-
end
181-
end
182-
f = Core.eval(CompiledCalls, def)
183-
stmt.args[1] = QuoteNode(f)
184-
deleteat!(stmt.args, 2:4)
168+
build_compiled_call!(stmt, methname, Base.llvmcall, stmt.args[2:4], code, idx, nargs, sparams)
185169
methodtables[idx] = Compiled()
186170
end
171+
elseif isexpr(stmt, :foreigncall) && scope isa Method
172+
f = lookup_stmt(code.code, stmt.args[1])
173+
if isa(f, Ptr)
174+
f = string(uuid4())
175+
elseif isexpr(f, :call)
176+
length(f.args) == 3 || continue
177+
f.args[1] === tuple || continue
178+
lib = f.args[3] isa String ? f.args[3] : f.args[3].value
179+
prefix = f.args[2] isa String ? f.args[2] : f.args[2].value
180+
f = Symbol(prefix, '_', lib)
181+
end
182+
# Punt on non literal ccall arguments for now
183+
if !(isa(f, String) || isa(f, Symbol) || isa(f, Ptr))
184+
continue
185+
end
186+
# TODO: Only compile one ccall per call and argument types
187+
uuid = uuid4()
188+
ustr = replace(string(uuid), '-'=>'_')
189+
methname = Symbol("ccall", '_', f, '_', ustr)
190+
nargs = stmt.args[5]
191+
build_compiled_call!(stmt, methname, :ccall, stmt.args[1:3], code, idx, nargs, sparams)
192+
methodtables[idx] = Compiled()
187193
end
188194
end
189195

190196
return code, methodtables
191197
end
198+
199+
function parametric_type_to_expr(t::Type)
200+
t isa Core.TypeofBottom && return t
201+
return t.hasfreetypevars ? Expr(:curly, t.name.name, ((tv-> tv isa TypeVar ? tv.name : tv).(t.parameters))...) : t
202+
end
203+
204+
# Handle :llvmcall & :foreigncall (issue #28)
205+
function build_compiled_call!(stmt, methname, fcall, typargs, code, idx, nargs, sparams)
206+
argnames = Any[Symbol("arg", string(i)) for i = 1:nargs]
207+
if fcall == :ccall
208+
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]
209+
# The result of this is useful to have next to you when reading this code:
210+
# f(x, y) = ccall(:jl_value_ptr, Ptr{Cvoid}, (Float32,Any), x, y)
211+
# @code_lowered f(2, 3)
212+
args = []
213+
for (atype, arg) in zip(ArgType, stmt.args[6:6+nargs-1])
214+
if atype === Any
215+
push!(args, arg)
216+
else
217+
@assert arg isa SSAValue
218+
unsafe_convert_expr = code.code[arg.id]
219+
cconvert_expr = code.code[unsafe_convert_expr.args[3].id]
220+
push!(args, cconvert_expr.args[3])
221+
end
222+
end
223+
else
224+
# Run a mini-interpreter to extract the types
225+
framecode = FrameCode(CompiledCalls, code; optimize=false)
226+
frame = Frame(framecode, prepare_framedata(framecode, []))
227+
idxstart = idx
228+
for i = 2:4
229+
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
230+
end
231+
frame.pc = idxstart
232+
if idxstart < idx
233+
while true
234+
pc = step_expr!(Compiled(), frame)
235+
pc == idx && break
236+
pc === nothing && error("this should never happen")
237+
end
238+
end
239+
cfunc, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
240+
args = stmt.args[5:end]
241+
end
242+
if isa(cfunc, Expr)
243+
cfunc = eval(cfunc)
244+
end
245+
if isa(cfunc, Symbol)
246+
cfunc = QuoteNode(cfunc)
247+
end
248+
if fcall == :ccall
249+
ArgType = Expr(:tuple, [parametric_type_to_expr(t) for t in ArgType]...)
250+
end
251+
if isa(RetType, SimpleVector)
252+
@assert length(RetType) == 1
253+
RetType = RetType[1]
254+
end
255+
RetType = parametric_type_to_expr(RetType)
256+
wrapargs = copy(argnames)
257+
for sparam in sparams
258+
push!(wrapargs, :(::Type{$sparam}))
259+
end
260+
if stmt.args[4] == :(:llvmcall)
261+
def = :(
262+
function $methname($(wrapargs...)) where {$(sparams...)}
263+
return $fcall($cfunc, llvmcall, $RetType, $ArgType, $(argnames...))
264+
end)
265+
else
266+
def = :(
267+
function $methname($(wrapargs...)) where {$(sparams...)}
268+
return $fcall($cfunc, $RetType, $ArgType, $(argnames...))
269+
end)
270+
end
271+
f = Core.eval(CompiledCalls, def)
272+
stmt.args[1] = QuoteNode(f)
273+
stmt.head = :call
274+
deleteat!(stmt.args, 2:length(stmt.args))
275+
append!(stmt.args, args)
276+
for i in 1:length(sparams)
277+
push!(stmt.args, :($(Expr(:static_parameter, 1))))
278+
end
279+
return
280+
end
281+

src/precompile.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function _precompile_()
3535
@assert precompile(Tuple{typeof(enter_call_expr), Expr})
3636
@assert precompile(Tuple{typeof(copy_codeinfo), Core.CodeInfo})
3737
@assert precompile(Tuple{typeof(optimize!), Core.CodeInfo, Module})
38+
@assert precompile(Tuple{typeof(optimize!), Core.CodeInfo, Method})
3839
@assert precompile(Tuple{typeof(set_structtype_const), Module, Symbol})
3940
@assert precompile(Tuple{typeof(namedtuple), Vector{Any}})
4041
@assert precompile(Tuple{typeof(resolvefc), Frame, Any})

src/types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ end
6767
const BREAKPOINT_EXPR = :($(QuoteNode(getproperty))($JuliaInterpreter, :__BREAKPOINT_MARKER__))
6868
function FrameCode(scope, src::CodeInfo; generator=false, optimize=true)
6969
if optimize
70-
src, methodtables = optimize!(copy_codeinfo(src), moduleof(scope))
70+
src, methodtables = optimize!(copy_codeinfo(src), scope)
7171
else
7272
src = copy_codeinfo(src)
7373
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(src.code))

test/interpret.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,13 @@ function hash220(x::Tuple{Ptr{UInt8},Int}, h::UInt)
453453
ccall(Base.memhash, UInt, (Ptr{UInt8}, Csize_t, UInt32), x[1], x[2], h % UInt32) + h
454454
end
455455
@test @interpret(hash220((Ptr{UInt8}(0),0), UInt(1))) == hash220((Ptr{UInt8}(0),0), UInt(1))
456+
457+
# ccall with type parameters
458+
@test (@interpret Base.unsafe_convert(Ptr{Int}, [1,2])) isa Ptr{Int}
459+
460+
# ccall with call to get the pointer
461+
cf = [@cfunction(fcfun, Int, (Int, Int))]
462+
function call_cf()
463+
ccall(cf[1], Int, (Int, Int), 1, 2)
464+
end
465+
@test (@interpret call_cf()) == call_cf()

0 commit comments

Comments
 (0)