Skip to content

Commit d319168

Browse files
vtjnashaviatesk
andauthored
Remove buggy linearization pass (#604)
There is only one bug in the base linearization pass that needs to be handled explicitly, otherwise linearization is fully guaranteed for IR. --------- Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent 1efae18 commit d319168

File tree

4 files changed

+51
-154
lines changed

4 files changed

+51
-154
lines changed

src/interpret.jl

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,6 @@ function lookup_var(frame, slot::SlotNumber)
99
throw(UndefVarError(frame.framecode.src.slotnames[slot.id]))
1010
end
1111

12-
function lookup_expr(frame, e::Expr)
13-
head = e.head
14-
head === :the_exception && return frame.framedata.last_exception[]
15-
if head === :static_parameter
16-
arg = e.args[1]::Int
17-
if isassigned(frame.framedata.sparams, arg)
18-
return frame.framedata.sparams[arg]
19-
else
20-
syms = sparam_syms(frame.framecode.scope::Method)
21-
throw(UndefVarError(syms[arg]))
22-
end
23-
end
24-
head === :boundscheck && length(e.args) == 0 && return true
25-
error("invalid lookup expr ", e)
26-
end
27-
2812
"""
2913
rhs = @lookup(frame, node)
3014
rhs = @lookup(mod, frame, node)
@@ -67,6 +51,32 @@ macro lookup(args...)
6751
end
6852
end
6953

54+
function lookup_expr(frame, e::Expr)
55+
head = e.head
56+
head === :the_exception && return frame.framedata.last_exception[]
57+
if head === :static_parameter
58+
arg = e.args[1]::Int
59+
if isassigned(frame.framedata.sparams, arg)
60+
return frame.framedata.sparams[arg]
61+
else
62+
syms = sparam_syms(frame.framecode.scope::Method)
63+
throw(UndefVarError(syms[arg]))
64+
end
65+
end
66+
head === :boundscheck && length(e.args) == 0 && return true
67+
if head === :call
68+
f = @lookup frame e.args[1]
69+
if (@static VERSION < v"1.11.0-DEV.1180" && true) && f === Core.svec
70+
# work around for a linearization bug in Julia (https://github.com/JuliaLang/julia/pull/52497)
71+
return f(Any[@lookup(frame, e.args[i]) for i in 2:length(e.args)]...)
72+
elseif f === Core.tuple
73+
# handling for ccall literal syntax
74+
return f(Any[@lookup(frame, e.args[i]) for i in 2:length(e.args)]...)
75+
end
76+
end
77+
error("invalid lookup expr ", e)
78+
end
79+
7080
# This is used only for new struct/abstract/primitive nodes.
7181
# The most important issue is that in these expressions, :call Exprs can be nested,
7282
# and hence our re-use of the `callargs` field of Frame would introduce
@@ -91,18 +101,26 @@ function lookup_or_eval(@nospecialize(recurse), frame, @nospecialize(node))
91101
if ex.head === :call
92102
f = ex.args[1]
93103
if f === Core.svec
94-
return Core.svec(ex.args[2:end]...)
104+
popfirst!(ex.args)
105+
return Core.svec(ex.args...)
95106
elseif f === Core.apply_type
96-
return Core.apply_type(ex.args[2:end]...)
97-
elseif f === Core.typeof
98-
return Core.typeof(ex.args[2])
99-
elseif f === Base.getproperty
107+
popfirst!(ex.args)
108+
return Core.apply_type(ex.args...)
109+
elseif f === typeof && length(ex.args) == 2
110+
return typeof(ex.args[2])
111+
elseif f === typeassert && length(ex.args) == 3
112+
return typeassert(ex.args[2], ex.args[3])
113+
elseif f === Base.getproperty && length(ex.args) == 3
100114
return Base.getproperty(ex.args[2], ex.args[3])
115+
elseif f === Core.Compiler.Val && length(ex.args) == 2
116+
return Core.Compiler.Val(ex.args[2])
117+
elseif f === Val && length(ex.args) == 2
118+
return Val(ex.args[2])
101119
else
102-
Base.invokelatest(error, "unknown call f ", f)
120+
Base.invokelatest(error, "unknown call f introduced by ccall lowering ", f)
103121
end
104122
else
105-
error("unknown expr ", ex)
123+
return lookup_expr(frame, ex)
106124
end
107125
elseif isa(node, Int) || isa(node, Number) # Number is slow, requires subtyping
108126
return node

src/optimize.jl

Lines changed: 6 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,5 @@
1-
const calllike = (:call, :foreigncall)
2-
31
const compiled_calls = Dict{Any,Any}()
42

5-
function extract_inner_call!(stmt::Expr, idx, once::Bool=false)
6-
(stmt.head === :toplevel || stmt.head === :thunk) && return nothing
7-
once |= stmt.head calllike
8-
for (i, a) in enumerate(stmt.args)
9-
isa(a, Expr) || continue
10-
# Make sure we don't "damage" special syntax that requires literals
11-
if i == 1 && stmt.head === :foreigncall
12-
continue
13-
end
14-
if i == 2 && stmt.head === :call && stmt.args[1] === :cglobal
15-
continue
16-
end
17-
ret = extract_inner_call!(a, idx, once) # doing this first extracts innermost calls
18-
ret !== nothing && return ret
19-
iscalllike = a.head calllike
20-
if once && iscalllike
21-
stmt.args[i] = NewSSAValue(idx)
22-
return a
23-
end
24-
end
25-
return nothing
26-
end
27-
28-
function replace_ssa(stmt::Expr, ssalookup)
29-
return Expr(stmt.head, Any[
30-
if isa(a, SSAValue)
31-
SSAValue(ssalookup[a.id])
32-
elseif isa(a, NewSSAValue)
33-
SSAValue(a.id)
34-
elseif isa(a, Expr)
35-
replace_ssa(a, ssalookup)
36-
else
37-
a
38-
end
39-
for a in stmt.args
40-
]...)
41-
end
42-
43-
function renumber_ssa!(stmts::Vector{Any}, ssalookup)
44-
# When updating jumps, when lines get split into multiple lines
45-
# (see "Un-nest :call expressions" below), we need to jump to the first of them.
46-
# Consequently we use the previous "old-code" offset and add one.
47-
# Fixes #455.
48-
jumplookup(l, idx) = idx > 1 ? l[idx-1] + 1 : idx
49-
50-
for (i, stmt) in enumerate(stmts)
51-
if isa(stmt, GotoNode)
52-
stmts[i] = GotoNode(jumplookup(ssalookup, stmt.label))
53-
elseif isa(stmt, SSAValue)
54-
stmts[i] = SSAValue(ssalookup[stmt.id])
55-
elseif isa(stmt, NewSSAValue)
56-
stmts[i] = SSAValue(stmt.id)
57-
elseif isexpr(stmt, :enter)
58-
stmt.args[end] = jumplookup(ssalookup, stmt.args[1]::Int)
59-
elseif isa(stmt, Expr)
60-
stmts[i] = replace_ssa(stmt, ssalookup)
61-
elseif isa(stmt, GotoIfNot)
62-
cond = stmt.cond
63-
if isa(cond, SSAValue)
64-
cond = SSAValue(ssalookup[cond.id])
65-
end
66-
stmts[i] = GotoIfNot(cond, jumplookup(ssalookup, stmt.dest))
67-
elseif isa(stmt, ReturnNode)
68-
val = stmt.val
69-
if isa(val, SSAValue)
70-
stmts[i] = ReturnNode(SSAValue(ssalookup[val.id]))
71-
end
72-
elseif @static (isdefined(Core.IR, :EnterNode) && true) && isa(stmt, Core.IR.EnterNode)
73-
stmts[i] = Core.IR.EnterNode(jumplookup(ssalookup, stmt.catch_dest))
74-
end
75-
end
76-
return stmts
77-
end
78-
79-
function compute_ssa_mapping_delete_statements!(code::CodeInfo, stmts::Vector{Int})
80-
stmts = unique!(sort!(stmts))
81-
ssalookup = collect(1:length(codelocs(code)))
82-
cnt = 1
83-
for i in 1:length(stmts)
84-
start = stmts[i] + 1
85-
stop = i == length(stmts) ? length(codelocs(code)) : stmts[i+1]
86-
ssalookup[start:stop] .-= cnt
87-
cnt += 1
88-
end
89-
return ssalookup
90-
end
91-
923
# Pre-frame-construction lookup
934
function lookup_stmt(stmts, arg)
945
if isa(arg, SSAValue)
@@ -179,7 +90,8 @@ function optimize!(code::CodeInfo, scope)
17990

18091
# Replace :llvmcall and :foreigncall with compiled variants. See
18192
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
182-
foreigncalls_idx = Int[]
93+
# Insert the foreigncall wrappers at the updated idxs
94+
methodtables = Vector{Union{Compiled,DispatchableMethod}}(undef, length(code.code))
18395
for (idx, stmt) in enumerate(code.code)
18496
# Foregincalls can be rhs of assignments
18597
if isexpr(stmt, :(=))
@@ -192,47 +104,16 @@ function optimize!(code::CodeInfo, scope)
192104
if (arg1 === :llvmcall || lookup_stmt(code.code, arg1) === Base.llvmcall) && isempty(sparams) && scope isa Method
193105
# Call via `invokelatest` to avoid compiling it until we need it
194106
Base.invokelatest(build_compiled_llvmcall!, stmt, code, idx, evalmod)
195-
push!(foreigncalls_idx, idx)
107+
methodtables[idx] = Compiled()
196108
end
197109
elseif stmt.head === :foreigncall && scope isa Method
198110
# Call via `invokelatest` to avoid compiling it until we need it
199111
Base.invokelatest(build_compiled_foreigncall!, stmt, code, sparams, evalmod)
200-
push!(foreigncalls_idx, idx)
112+
methodtables[idx] = Compiled()
201113
end
202114
end
203115
end
204116

205-
## Un-nest :call expressions (so that there will be only one :call per line)
206-
# This will allow us to re-use args-buffers rather than having to allocate new ones each time.
207-
old_code, old_codelocs = code.code, codelocs(code)
208-
code.code = new_code = eltype(old_code)[]
209-
code.codelocs = new_codelocs = Int32[]
210-
ssainc = fill(1, length(old_code))
211-
for (i, stmt) in enumerate(old_code)
212-
loc = old_codelocs[i]
213-
if isa(stmt, Expr)
214-
inner = extract_inner_call!(stmt, length(new_code)+1)
215-
while inner !== nothing
216-
push!(new_code, inner)
217-
push!(new_codelocs, loc)
218-
ssainc[i] += 1
219-
inner = extract_inner_call!(stmt, length(new_code)+1)
220-
end
221-
end
222-
push!(new_code, stmt)
223-
push!(new_codelocs, loc)
224-
end
225-
# Fix all the SSAValues and GotoNodes
226-
ssalookup = cumsum(ssainc)
227-
renumber_ssa!(new_code, ssalookup)
228-
code.ssavaluetypes = length(new_code)
229-
230-
# Insert the foreigncall wrappers at the updated idxs
231-
methodtables = Vector{Union{Compiled,DispatchableMethod}}(undef, length(code.code))
232-
for idx in foreigncalls_idx
233-
methodtables[ssalookup[idx]] = Compiled()
234-
end
235-
236117
return code, methodtables
237118
end
238119

@@ -255,7 +136,7 @@ function parametric_type_to_expr(@nospecialize(t::Type))
255136
return t
256137
end
257138

258-
function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
139+
function build_compiled_llvmcall!(stmt::Expr, code::CodeInfo, idx::Int, evalmod::Module)
259140
# Run a mini-interpreter to extract the types
260141
framecode = FrameCode(CompiledCalls, code; optimize=false)
261142
frame = Frame(framecode, prepare_framedata(framecode, []))
@@ -292,9 +173,8 @@ function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
292173
append!(stmt.args, args)
293174
end
294175

295-
296176
# Handle :llvmcall & :foreigncall (issue #28)
297-
function build_compiled_foreigncall!(stmt::Expr, code, sparams::Vector{Symbol}, evalmod)
177+
function build_compiled_foreigncall!(stmt::Expr, code::CodeInfo, sparams::Vector{Symbol}, evalmod::Module)
298178
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
299179
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector
300180

src/types.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@ which will cause all calls to be evaluated via the interpreter.
66
struct Compiled end
77
Base.similar(::Compiled, sz) = Compiled() # to support similar(stack, 0)
88

9-
# A type used transiently in renumbering CodeInfo SSAValues (to distinguish a new SSAValue from an old one)
10-
struct NewSSAValue
11-
id::Int
12-
end
13-
149
# Our own replacements for Core types. We need to do this to ensure we can tell the difference
1510
# between "data" (Core types) and "code" (our types) if we step into Core.Compiler
1611
struct SSAValue

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ function scan_ssa_use!(used::BitSet, @nospecialize(stmt))
8585
while iterval !== nothing
8686
useref, state = iterval
8787
val = Core.Compiler.getindex(useref)
88+
if (@static VERSION < v"1.11.0-DEV.1180" && true) && isexpr(val, :call)
89+
# work around for a linearization bug in Julia (https://github.com/JuliaLang/julia/pull/52497)
90+
scan_ssa_use!(used, val)
91+
end
8892
if isa(val, SSAValue)
8993
push!(used, val.id)
9094
end

0 commit comments

Comments
 (0)