Skip to content

Commit 2c7ef57

Browse files
authored
Interpret most of Core.Compiler (#277)
* Define our own `SSAValue` and `SlotNumber` types to distinguish them from values generated in `Core.Compiler` * `eval` the compiled `ccalls` from `Core.Compiler` in that module * re-use previous compiled-`ccalls` based on arg types
1 parent f09fc63 commit 2c7ef57

File tree

7 files changed

+139
-42
lines changed

7 files changed

+139
-42
lines changed

src/JuliaInterpreter.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module JuliaInterpreter
22

33
using Base.Meta
44
import Base: +, -, convert, isless
5-
using Core: CodeInfo, SSAValue, SlotNumber, TypeMapEntry, SimpleVector, LineInfoNode, GotoNode, Slot,
5+
using Core: CodeInfo, TypeMapEntry, SimpleVector, LineInfoNode, GotoNode, Slot,
66
GeneratedFunctionStub, MethodInstance, NewvarNode, TypeName
77

88
using UUIDs
@@ -18,7 +18,7 @@ export @interpret, Compiled, Frame, root, leaf,
1818
debug_command, @bp, break_on, break_off
1919

2020
module CompiledCalls
21-
# This module is for handling intrinsics that must be compiled (llvmcall)
21+
# This module is for handling intrinsics that must be compiled (llvmcall) as well as ccalls
2222
end
2323

2424
# "Backport" of https://github.com/JuliaLang/julia/pull/31536
@@ -70,12 +70,43 @@ function set_compiled_methods()
7070
push!(compiled_methods, which(subtypes, Tuple{Module, Type}))
7171
push!(compiled_methods, which(subtypes, Tuple{Type}))
7272

73-
push!(compiled_modules, Core.Compiler)
73+
# Anything that ccalls jl_typeinf_begin cannot currently be handled
74+
for finf in (Core.Compiler.typeinf_code, Core.Compiler.typeinf_ext, Core.Compiler.typeinf_type)
75+
for m in methods(finf)
76+
push!(compiled_methods, m)
77+
end
78+
end
79+
7480
push!(compiled_modules, Base.Threads)
7581
end
7682

7783
function __init__()
7884
set_compiled_methods()
85+
# If we interpret into Core.Compiler, we need to take precautions to avoid needing
86+
# inference of JuliaInterpreter methods in the middle of a `ccall(:jl_typeinf_begin, ...)`
87+
# block.
88+
# for (sym, RT, AT) in ((:jl_typeinf_begin, Cvoid, ()),
89+
# (:jl_typeinf_end, Cvoid, ()),
90+
# (:jl_isa_compileable_sig, Int32, (Any, Any)),
91+
# (:jl_compress_ast, Any, (Any, Any)),
92+
# # (:jl_set_method_inferred, Ref{Core.CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt)),
93+
# (:jl_method_instance_add_backedge, Cvoid, (Any, Any)),
94+
# (:jl_method_table_add_backedge, Cvoid, (Any, Any, Any)),
95+
# (:jl_new_code_info_uninit, Ref{CodeInfo}, ()),
96+
# (:jl_uncompress_argnames, Vector{Symbol}, (Any,)),
97+
# (:jl_get_tls_world_age, UInt, ()),
98+
# (:jl_call_in_typeinf_world, Any, (Ptr{Ptr{Cvoid}}, Cint)),
99+
# (:jl_value_ptr, Any, (Ptr{Cvoid},)),
100+
# (:jl_value_ptr, Ptr{Cvoid}, (Any,)))
101+
# fname = Symbol(:ccall_, sym)
102+
# qsym = QuoteNode(sym)
103+
# argnames = [Symbol(:arg_, string(i)) for i = 1:length(AT)]
104+
# TAT = Expr(:tuple, [parametric_type_to_expr(t) for t in AT]...)
105+
# def = :($fname($(argnames...)) = ccall($qsym, $RT, $TAT, $(argnames...)))
106+
# f = Core.eval(Core.Compiler, def)
107+
# compiled_calls[(qsym, RT, Core.svec(AT...), Core.Compiler)] = f
108+
# precompile(f, AT)
109+
# end
79110
end
80111

81112
include("precompile.jl")

src/commands.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ function maybe_step_through_wrapper!(@nospecialize(recurse), frame::Frame)
213213
last = stmts[end-1]
214214
isexpr(last, :(=)) && (last = last.args[2])
215215
is_kw = isa(scope, Method) && startswith(String(Base.unwrap_unionall(Base.unwrap_unionall(scope.sig).parameters[1]).name.name), "#kw")
216-
if is_kw || isexpr(last, :call) && any(isequal(Core.SlotNumber(1)), last.args)
216+
if is_kw || isexpr(last, :call) && any(isequal(SlotNumber(1)), last.args)
217217
# If the last expr calls #self# or passes it to an implementation method,
218218
# this is a wrapper function that we might want to step through
219219
while frame.pc != length(stmts)-1

src/interpret.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ function bypass_builtins(frame, call_expr, pc)
178178
if isa(tme, Compiled)
179179
fargs = collect_args(frame, call_expr)
180180
f = to_function(fargs[1])
181-
if parentmodule(f) === JuliaInterpreter.CompiledCalls
181+
fmod = parentmodule(f)
182+
if fmod === JuliaInterpreter.CompiledCalls || fmod === Core.Compiler
182183
return Some{Any}(Base.invokelatest(f, fargs[2:end]...))
183184
else
184185
return Some{Any}(f(fargs[2:end]...))

src/optimize.jl

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
const calllike = Set([:call, :foreigncall])
22

3+
const compiled_calls = Dict{Any,Any}()
4+
35
function extract_inner_call!(stmt, idx, once::Bool=false)
46
isa(stmt, Expr) || return nothing
57
(stmt.head == :toplevel || stmt.head == :thunk) && return nothing
@@ -127,8 +129,10 @@ which this will run) and ensures that no statement includes nested `:call` expre
127129
"""
128130
function optimize!(code::CodeInfo, scope)
129131
mod = moduleof(scope)
132+
evalmod = mod == Core.Compiler ? Core.Compiler : CompiledCalls
130133
sparams = scope isa Method ? Symbol[sparam_syms(scope)...] : Symbol[]
131134
code.inferred && error("optimization of inferred code not implemented")
135+
replace_coretypes!(code)
132136
# TODO: because of builtins.jl, for CodeInfos like
133137
# %1 = Core.apply_type
134138
# %2 = (%1)(args...)
@@ -163,7 +167,7 @@ function optimize!(code::CodeInfo, scope)
163167
ustr = replace(string(uuid), '-'=>'_')
164168
methname = Symbol("llvmcall_", ustr)
165169
nargs = length(stmt.args)-4
166-
delete_idx = build_compiled_call!(stmt, methname, Base.llvmcall, stmt.args[2:4], code, idx, nargs, sparams)
170+
delete_idx = build_compiled_call!(stmt, methname, Base.llvmcall, stmt.args[2:4], code, idx, nargs, sparams, evalmod)
167171
push!(foreigncalls_idx, idx)
168172
append!(delete_idxs, delete_idx)
169173
end
@@ -189,7 +193,7 @@ function optimize!(code::CodeInfo, scope)
189193
ustr = replace(string(uuid), '-'=>'_')
190194
methname = Symbol("ccall", '_', f, '_', ustr)
191195
nargs = stmt.args[5]
192-
delete_idx = build_compiled_call!(stmt, methname, :ccall, stmt.args[1:3], code, idx, nargs, sparams)
196+
delete_idx = build_compiled_call!(stmt, methname, :ccall, stmt.args[1:3], code, idx, nargs, sparams, evalmod)
193197
push!(foreigncalls_idx, idx)
194198
append!(delete_idxs, delete_idx)
195199
end
@@ -239,14 +243,15 @@ end
239243
function parametric_type_to_expr(t::Type)
240244
t isa Core.TypeofBottom && return t
241245
t isa UnionAll && (t = t.body)
242-
if t <: Vararg
246+
if t <: Vararg
243247
return Expr(:(...), t.parameters[1])
244248
end
245249
return t.hasfreetypevars ? Expr(:curly, t.name.name, ((tv-> tv isa TypeVar ? tv.name : tv).(t.parameters))...) : t
246250
end
247251

248252
# Handle :llvmcall & :foreigncall (issue #28)
249-
function build_compiled_call!(stmt, methname, fcall, typargs, code, idx, nargs, sparams)
253+
function build_compiled_call!(stmt, methname, fcall, typargs, code, idx, nargs, sparams, evalmod)
254+
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
250255
argnames = Any[Symbol("arg", string(i)) for i = 1:nargs]
251256
delete_idx = Int[]
252257
if fcall == :ccall
@@ -287,48 +292,74 @@ function build_compiled_call!(stmt, methname, fcall, typargs, code, idx, nargs,
287292
cfunc, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
288293
args = stmt.args[5:end]
289294
end
290-
if isa(cfunc, Expr)
295+
if isa(cfunc, Expr) # specification by tuple, e.g., (:clock, "libc")
291296
cfunc = eval(cfunc)
292297
end
293298
if isa(cfunc, Symbol)
294299
cfunc = QuoteNode(cfunc)
295300
end
296-
if fcall == :ccall
297-
ArgType = Expr(:tuple, [parametric_type_to_expr(t) for t in ArgType]...)
298-
end
299301
if isa(RetType, SimpleVector)
300302
@assert length(RetType) == 1
301303
RetType = RetType[1]
302304
end
303-
RetType = parametric_type_to_expr(RetType)
304-
wrapargs = copy(argnames)
305-
for sparam in sparams
306-
push!(wrapargs, :(::Val{$sparam}))
307-
end
308-
if stmt.args[4] == :(:llvmcall)
309-
def = :(
310-
function $methname($(wrapargs...)) where {$(sparams...)}
311-
return $fcall($cfunc, llvmcall, $RetType, $ArgType, $(argnames...))
312-
end)
313-
elseif stmt.args[4] == :(:stdcall)
314-
def = :(
315-
function $methname($(wrapargs...)) where {$(sparams...)}
316-
return $fcall($cfunc, stdcall, $RetType, $ArgType, $(argnames...))
317-
end)
318-
else
319-
def = :(
320-
function $methname($(wrapargs...)) where {$(sparams...)}
321-
return $fcall($cfunc, $RetType, $ArgType, $(argnames...))
322-
end)
305+
cc_key = (cfunc, RetType, ArgType, evalmod) # compiled call key
306+
f = get(compiled_calls, cc_key, nothing)
307+
if f === nothing
308+
if fcall == :ccall
309+
ArgType = Expr(:tuple, [parametric_type_to_expr(t) for t in ArgType]...)
310+
end
311+
RetType = parametric_type_to_expr(RetType)
312+
wrapargs = copy(argnames)
313+
for sparam in sparams
314+
push!(wrapargs, :(::$TVal{$sparam}))
315+
end
316+
if stmt.args[4] == :(:llvmcall)
317+
def = :(
318+
function $methname($(wrapargs...)) where {$(sparams...)}
319+
return $fcall($cfunc, llvmcall, $RetType, $ArgType, $(argnames...))
320+
end)
321+
elseif stmt.args[4] == :(:stdcall)
322+
def = :(
323+
function $methname($(wrapargs...)) where {$(sparams...)}
324+
return $fcall($cfunc, stdcall, $RetType, $ArgType, $(argnames...))
325+
end)
326+
else
327+
def = :(
328+
function $methname($(wrapargs...)) where {$(sparams...)}
329+
return $fcall($cfunc, $RetType, $ArgType, $(argnames...))
330+
end)
331+
end
332+
f = Core.eval(evalmod, def)
333+
compiled_calls[cc_key] = f
323334
end
324-
f = Core.eval(CompiledCalls, def)
325335
stmt.args[1] = QuoteNode(f)
326336
stmt.head = :call
327337
deleteat!(stmt.args, 2:length(stmt.args))
328338
append!(stmt.args, args)
329339
for i in 1:length(sparams)
330-
push!(stmt.args, :($Val($(Expr(:static_parameter, i)))))
340+
push!(stmt.args, :($TVal($(Expr(:static_parameter, i)))))
331341
end
332342
return delete_idx
333343
end
334344

345+
function replace_coretypes!(src; rev::Bool=false)
346+
if isa(src, CodeInfo)
347+
replace_coretypes_list!(src.code; rev=rev)
348+
elseif isa(src, Expr)
349+
replace_coretypes_list!(src.args; rev=rev)
350+
end
351+
return src
352+
end
353+
354+
function replace_coretypes_list!(list; rev::Bool)
355+
for (i, stmt) in enumerate(list)
356+
if isa(stmt, rev ? SSAValue : Core.SSAValue)
357+
list[i] = rev ? Core.SSAValue(stmt.id) : SSAValue(stmt.id)
358+
elseif isa(stmt, rev ? SlotNumber : Core.SlotNumber)
359+
list[i] = rev ? Core.SlotNumber(stmt.id) : SlotNumber(stmt.id)
360+
elseif isa(stmt, Expr)
361+
replace_coretypes!(stmt; rev=rev)
362+
end
363+
end
364+
return nothing
365+
end

src/types.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ struct NewSSAValue
1111
id::Int
1212
end
1313

14+
# Our own replacements for Core types. We need to do this to ensure we can tell the difference
15+
# between "data" (Core types) and "code" (our types) if we step into Core.Compiler
16+
struct SSAValue
17+
id::Int
18+
end
19+
struct SlotNumber
20+
id::Int
21+
end
22+
1423
# Breakpoint support
1524
truecondition(frame) = true
1625
falsecondition(frame) = false
@@ -70,7 +79,7 @@ function FrameCode(scope, src::CodeInfo; generator=false, optimize=true)
7079
if optimize
7180
src, methodtables = optimize!(copy_codeinfo(src), scope)
7281
else
73-
src = copy_codeinfo(src)
82+
src = replace_coretypes!(copy_codeinfo(src))
7483
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(src.code))
7584
end
7685
breakpoints = Vector{BreakpointState}(undef, length(src.code))

src/utils.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,32 @@ function find_used(code::CodeInfo)
6161
used = BitSet()
6262
stmts = code.code
6363
for stmt in stmts
64-
Core.Compiler.scan_ssa_use!(push!, used, stmt)
65-
if isexpr(stmt, :struct_type) # this one is missed
64+
scan_ssa_use!(used, stmt)
65+
if isexpr(stmt, :struct_type) # this one is missed by Core.Compiler.userefs
6666
for a in stmt.args
67-
Core.Compiler.scan_ssa_use!(push!, used, a)
67+
scan_ssa_use!(used, a)
6868
end
6969
end
7070
end
7171
return used
7272
end
7373

74+
function scan_ssa_use!(used::BitSet, @nospecialize(stmt))
75+
if isa(stmt, SSAValue)
76+
push!(used, stmt.id)
77+
end
78+
iter = Core.Compiler.userefs(stmt)
79+
iterval = Core.Compiler.iterate(iter)
80+
while iterval !== nothing
81+
useref, state = iterval
82+
val = Core.Compiler.getindex(useref)
83+
if isa(val, SSAValue)
84+
push!(used, val.id)
85+
end
86+
iterval = Core.Compiler.iterate(iter, state)
87+
end
88+
end
89+
7490
## Predicates
7591

7692
is_goto_node(@nospecialize(node)) = isa(node, GotoNode) || isexpr(node, :gotoifnot)
@@ -265,7 +281,9 @@ function print_framecode(io::IO, framecode::FrameCode; pc=0, range=1:nstatements
265281
offset = lineoffset(framecode)
266282
ndline = isempty(lt) ? 0 : ndigits(getline(lt[end]) + offset)
267283
nullline = " "^ndline
268-
code = framecode_lines(framecode)
284+
src = copy_codeinfo(framecode.src)
285+
replace_coretypes!(src; rev=true)
286+
code = framecode_lines(src)
269287
isfirst = true
270288
for (stmtidx, stmtline) in enumerate(code)
271289
stmtidx range || continue

test/interpret.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ f113(;x) = x
337337
end
338338
frame = JuliaInterpreter.enter_call(f_multi, 1)
339339
nlocals = length(frame.framedata.locals)
340-
@test_throws UndefVarError JuliaInterpreter.lookup_var(frame, Core.SlotNumber(nlocals))
340+
@test_throws UndefVarError JuliaInterpreter.lookup_var(frame, JuliaInterpreter.SlotNumber(nlocals))
341341
stack = [frame]
342342
locals = JuliaInterpreter.locals(frame)
343343
@test length(locals) == 2
@@ -383,7 +383,7 @@ end
383383
end
384384
end
385385
@test @interpret(test_never_different(10)) === nothing
386-
386+
387387
end
388388

389389
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/130
@@ -516,3 +516,10 @@ end
516516
# Test exception type for undefined variables
517517
f() = s = s + 1
518518
@test_throws UndefVarError @interpret f()
519+
520+
# Handling of SSAValues
521+
function f()
522+
z = [Core.SSAValue(5),]
523+
repr(z[1])
524+
end
525+
@test @interpret f() == f()

0 commit comments

Comments
 (0)