Skip to content

Commit 0d2f109

Browse files
committed
Build compiled methods to handle llvmcall
1 parent 8f55669 commit 0d2f109

File tree

4 files changed

+114
-5
lines changed

4 files changed

+114
-5
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1.1"
44

55
[deps]
66
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
78

89
[extras]
910
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"

src/JuliaInterpreter.jl

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@ import Base: +, -, convert, isless
55
using Core: CodeInfo, SSAValue, SlotNumber, TypeMapEntry, SimpleVector, LineInfoNode, GotoNode, Slot,
66
GeneratedFunctionStub, MethodInstance, NewvarNode, TypeName
77

8+
using UUIDs
9+
810
export @enter, @make_stack, @interpret, Compiled, JuliaStackFrame
911

12+
module CompiledCalls
13+
# This module is for handling intrinsics that must be compiled (llvmcall)
14+
end
15+
1016
"""
1117
`Compiled` is a trait indicating that any `:call` expressions should be evaluated
1218
using Julia's normal compiled-code evaluation. The alternative is to pass `stack=JuliaStackFrame[]`,
@@ -50,7 +56,7 @@ Important fields:
5056
struct JuliaFrameCode
5157
scope::Union{Method,Module}
5258
code::CodeInfo
53-
methodtables::Vector{TypeMapEntry} # line-by-line method tables for generic-function :call Exprs
59+
methodtables::Vector{Union{Compiled,TypeMapEntry}} # line-by-line method tables for generic-function :call Exprs
5460
used::BitSet
5561
wrapper::Bool
5662
generator::Bool
@@ -63,10 +69,14 @@ function JuliaFrameCode(frame::JuliaFrameCode; wrapper = frame.wrapper, generato
6369
wrapper, generator, fullpath)
6470
end
6571

66-
function JuliaFrameCode(scope, code::CodeInfo; wrapper=false, generator=false, fullpath=true)
67-
code = optimize!(copy_codeinfo(code), moduleof(scope))
72+
function JuliaFrameCode(scope, code::CodeInfo; wrapper=false, generator=false, fullpath=true, optimize=true)
73+
if optimize
74+
code, methodtables = optimize!(copy_codeinfo(code), moduleof(scope))
75+
else
76+
code = copy_codeinfo(code)
77+
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(code.code))
78+
end
6879
used = find_used(code)
69-
methodtables = Vector{TypeMapEntry}(undef, length(code.code))
7080
return JuliaFrameCode(scope, code, methodtables, used, wrapper, generator, fullpath)
7181
end
7282

@@ -612,6 +622,29 @@ function renumber_ssa!(stmts::Vector{Any}, ssalookup)
612622
return stmts
613623
end
614624

625+
# Pre-frame-construction lookup
626+
function lookup_stmt(stmts, arg)
627+
if isa(arg, SSAValue)
628+
arg = stmts[arg.id]
629+
end
630+
if isa(arg, QuoteNode)
631+
arg = arg.value
632+
end
633+
return arg
634+
end
635+
636+
function smallest_ref(stmts, arg, idmin)
637+
if isa(arg, SSAValue)
638+
idmin = min(idmin, arg.id)
639+
return smallest_ref(stmts, stmts[arg.id], idmin)
640+
elseif isa(arg, Expr)
641+
for a in arg.args
642+
idmin = smallest_ref(stmts, a, idmin)
643+
end
644+
end
645+
return idmin
646+
end
647+
615648
function lookup_global_refs!(ex::Expr)
616649
(ex.head == :isdefined || ex.head == :thunk || ex.head == :toplevel) && return nothing
617650
for (i, a) in enumerate(ex.args)
@@ -676,7 +709,48 @@ function optimize!(code::CodeInfo, mod::Module)
676709
ssalookup = cumsum(ssainc)
677710
renumber_ssa!(new_code, ssalookup)
678711
code.ssavaluetypes = length(new_code)
679-
return code
712+
713+
# Replace :llvmcall and :foreigncall with compiled variants. See
714+
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
715+
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(code.code))
716+
for (idx, stmt) in enumerate(code.code)
717+
if isexpr(stmt, :call)
718+
# Check for :llvmcall
719+
arg1 = stmt.args[1]
720+
if arg1 == :llvmcall || lookup_stmt(code.code, arg1) == Base.llvmcall
721+
uuid = uuid4()
722+
ustr = replace(string(uuid), '-'=>'_')
723+
methname = Symbol("llvmcall_", ustr)
724+
nargs = length(stmt.args)-4
725+
argnames = [Symbol("arg", string(i)) for i = 1:nargs]
726+
# Run a mini-interpreter to extract the types
727+
framecode = JuliaFrameCode(CompiledCalls, code; optimize=false)
728+
frame = prepare_locals(framecode, [])
729+
idxstart = idx
730+
for i = 2:4
731+
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
732+
end
733+
frame.pc[] = JuliaProgramCounter(idxstart)
734+
while true
735+
pc = step_expr!(Compiled(), frame)
736+
convert(Int, pc) == idx && break
737+
pc === nothing && error("this should never happen")
738+
end
739+
str, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
740+
def = quote
741+
function $methname($(argnames...))
742+
return Base.llvmcall($str, $RetType, $ArgType, $(argnames...))
743+
end
744+
end
745+
f = Core.eval(CompiledCalls, def)
746+
stmt.args[1] = QuoteNode(f)
747+
deleteat!(stmt.args, 2:4)
748+
methodtables[idx] = Compiled()
749+
end
750+
end
751+
end
752+
753+
return code, methodtables
680754
end
681755

682756
function prepare_locals(framecode, argvals::Vector{Any})

src/interpret.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ function evaluate_call!(::Compiled, frame::JuliaStackFrame, call_expr::Expr, pc;
177177
end
178178

179179
function evaluate_call!(stack, frame::JuliaStackFrame, call_expr::Expr, pc; exec!::Function=finish_and_return!)
180+
idx = convert(Int, pc)
181+
if isassigned(frame.code.methodtables, idx)
182+
tme = frame.code.methodtables[idx]
183+
if isa(tme, Compiled)
184+
fargs = collect_args(frame, call_expr)
185+
f = to_function(fargs[1])
186+
return f(fargs[2:end]...)
187+
end
188+
end
180189
ret = maybe_evaluate_builtin(frame, call_expr)
181190
isa(ret, Some{Any}) && return ret.value
182191
fargs = collect_args(frame, call_expr)

test/interpret.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,28 @@ if isdefined(Core.Compiler, :SNCA)
226226
cfg = Core.Compiler.compute_basic_blocks(ci.code)
227227
@test isa(@interpret(Core.Compiler.SNCA(cfg)), Vector{Int})
228228
end
229+
230+
# llvmcall
231+
function add1234(x::Tuple{Int32,Int32,Int32,Int32})
232+
Base.llvmcall("""%3 = extractvalue [4 x i32] %0, 0
233+
%4 = extractvalue [4 x i32] %0, 1
234+
%5 = extractvalue [4 x i32] %0, 2
235+
%6 = extractvalue [4 x i32] %0, 3
236+
%7 = extractvalue [4 x i32] %1, 0
237+
%8 = extractvalue [4 x i32] %1, 1
238+
%9 = extractvalue [4 x i32] %1, 2
239+
%10 = extractvalue [4 x i32] %1, 3
240+
%11 = add i32 %3, %7
241+
%12 = add i32 %4, %8
242+
%13 = add i32 %5, %9
243+
%14 = add i32 %6, %10
244+
%15 = insertvalue [4 x i32] undef, i32 %11, 0
245+
%16 = insertvalue [4 x i32] %15, i32 %12, 1
246+
%17 = insertvalue [4 x i32] %16, i32 %13, 2
247+
%18 = insertvalue [4 x i32] %17, i32 %14, 3
248+
ret [4 x i32] %18""",Tuple{Int32,Int32,Int32,Int32},
249+
Tuple{Tuple{Int32,Int32,Int32,Int32},Tuple{Int32,Int32,Int32,Int32}},
250+
(Int32(1),Int32(2),Int32(3),Int32(4)),
251+
x)
252+
end
253+
@test @interpret(add1234(map(Int32,(2,3,4,5)))) === map(Int32,(3,5,7,9))

0 commit comments

Comments
 (0)