Skip to content

Commit e83b317

Browse files
vtjnashaviatesk
andauthored
inference: propagate variable changes to all exception frames (#42081)
* inference: propagate variable changes to all exception frames Fix #42022 * Update test/compiler/inference.jl * Update test/compiler/inference.jl Co-authored-by: Shuhei Kadowaki <[email protected]> * fixup! inference: propagate variable changes to all exception frames Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent 03e7b23 commit e83b317

File tree

3 files changed

+156
-35
lines changed

3 files changed

+156
-35
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,18 +1764,16 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
17641764
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
17651765
while frame.pc´´ <= n
17661766
# make progress on the active ip set
1767-
local pc::Int = frame.pc´´ # current program-counter
1767+
local pc::Int = frame.pc´´
17681768
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
17691769
#print(pc,": ",s[pc],"\n")
17701770
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
17711771
if pc == frame.pc´´
1772-
# need to update pc´´ to point at the new lowest instruction in W
1773-
min_pc = _bits_findnext(W.bits, pc + 1)
1774-
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
1772+
# want to update pc´´ to point at the new lowest instruction in W
1773+
frame.pc´´ = pc´
17751774
end
17761775
delete!(W, pc)
17771776
frame.currpc = pc
1778-
frame.cur_hand = frame.handler_at[pc]
17791777
edges = frame.stmt_edges[pc]
17801778
edges === nothing || empty!(edges)
17811779
frame.stmt_info[pc] = nothing
@@ -1817,7 +1815,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18171815
pc´ = l
18181816
else
18191817
# general case
1820-
frame.handler_at[l] = frame.cur_hand
18211818
changes_else = changes
18221819
if isa(condt, Conditional)
18231820
changes_else = conditional_changes(changes_else, condt.elsetype, condt.var)
@@ -1877,7 +1874,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18771874
elseif hd === :enter
18781875
stmt = stmt::Expr
18791876
l = stmt.args[1]::Int
1880-
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
18811877
# propagate type info to exception handler
18821878
old = states[l]
18831879
newstate_catch = stupdate!(old, changes)
@@ -1889,12 +1885,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18891885
states[l] = newstate_catch
18901886
end
18911887
typeassert(states[l], VarTable)
1892-
frame.handler_at[l] = frame.cur_hand
18931888
elseif hd === :leave
1894-
stmt = stmt::Expr
1895-
for i = 1:((stmt.args[1])::Int)
1896-
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
1897-
end
18981889
else
18991890
if hd === :(=)
19001891
stmt = stmt::Expr
@@ -1928,16 +1919,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19281919
ssavaluetypes[pc] = t
19291920
end
19301921
end
1931-
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
1932-
# propagate new type info to exception handler
1933-
# the handling for Expr(:enter) propagates all changes from before the try/catch
1934-
# so this only needs to propagate any changes
1935-
l = frame.cur_hand.first::Int
1936-
if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false
1937-
if l < frame.pc´´
1938-
frame.pc´´ = l
1922+
if isa(changes, StateUpdate)
1923+
let cur_hand = frame.handler_at[pc], l, enter
1924+
while cur_hand != 0
1925+
enter = frame.src.code[cur_hand]
1926+
l = (enter::Expr).args[1]::Int
1927+
# propagate new type info to exception handler
1928+
# the handling for Expr(:enter) propagates all changes from before the try/catch
1929+
# so this only needs to propagate any changes
1930+
if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false
1931+
if l < frame.pc´´
1932+
frame.pc´´ = l
1933+
end
1934+
push!(W, l)
1935+
end
1936+
cur_hand = frame.handler_at[cur_hand]
19391937
end
1940-
push!(W, l)
19411938
end
19421939
end
19431940
end
@@ -1950,7 +1947,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19501947
end
19511948

19521949
pc´ > n && break # can't proceed with the fast-path fall-through
1953-
frame.handler_at[pc´] = frame.cur_hand
19541950
newstate = stupdate!(states[pc´], changes)
19551951
if isa(stmt, GotoNode) && frame.pc´´ < pc´
19561952
# if we are processing a goto node anyways,
@@ -1961,7 +1957,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19611957
states[pc´] = newstate
19621958
end
19631959
push!(W, pc´)
1964-
pc = frame.pc´´
1960+
break
19651961
elseif newstate !== nothing
19661962
states[pc´] = newstate
19671963
pc = pc´
@@ -1971,6 +1967,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19711967
break
19721968
end
19731969
end
1970+
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
19741971
end
19751972
frame.dont_work_on_me = false
19761973
nothing

base/compiler/inferencestate.jl

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ mutable struct InferenceState
2828
pc´´::LineNum
2929
nstmts::Int
3030
# current exception handler info
31-
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
32-
handler_at::Vector{Any}
33-
n_handlers::Int
31+
handler_at::Vector{LineNum}
3432
# ssavalue sparsity and restart info
3533
ssavalue_uses::Vector{BitSet}
3634
throw_blocks::BitSet
@@ -86,25 +84,21 @@ mutable struct InferenceState
8684
throw_blocks = find_throw_blocks(code)
8785

8886
# exception handlers
89-
cur_hand = nothing
90-
handler_at = Any[ nothing for i=1:n ]
91-
n_handlers = 0
92-
93-
W = BitSet()
94-
push!(W, 1) #initial pc to visit
87+
ip = BitSet()
88+
handler_at = compute_trycatch(src.code, ip)
89+
push!(ip, 1)
9590

9691
mod = isa(def, Method) ? def.module : def
97-
9892
valid_worlds = WorldRange(src.min_world,
9993
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
94+
10095
frame = new(
10196
InferenceParams(interp), result, linfo,
10297
sp, slottypes, mod, 0,
10398
IdSet{InferenceState}(), IdSet{InferenceState}(),
10499
src, get_world_counter(interp), valid_worlds,
105100
nargs, s_types, s_edges, stmt_info,
106-
Union{}, W, 1, n,
107-
cur_hand, handler_at, n_handlers,
101+
Union{}, ip, 1, n, handler_at,
108102
ssavalue_uses, throw_blocks,
109103
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
110104
Vector{InferenceState}(), # callers_in_cycle
@@ -118,6 +112,91 @@ mutable struct InferenceState
118112
end
119113
end
120114

115+
function compute_trycatch(code::Vector{Any}, ip::BitSet)
116+
# The goal initially is to record the frame like this for the state at exit:
117+
# 1: (enter 3) # == 0
118+
# 3: (expr) # == 1
119+
# 3: (leave 1) # == 1
120+
# 4: (expr) # == 0
121+
# then we can find all trys by walking backwards from :enter statements,
122+
# and all catches by looking at the statement after the :enter
123+
n = length(code)
124+
empty!(ip)
125+
ip.offset = 0 # for _bits_findnext
126+
push!(ip, n + 1)
127+
handler_at = fill(0, n)
128+
129+
# start from all :enter statements and record the location of the try
130+
for pc = 1:n
131+
stmt = code[pc]
132+
if isexpr(stmt, :enter)
133+
l = stmt.args[1]::Int
134+
handler_at[pc + 1] = pc
135+
push!(ip, pc + 1)
136+
handler_at[l] = pc
137+
push!(ip, l)
138+
end
139+
end
140+
141+
# now forward those marks to all :leave statements
142+
pc´´ = 0
143+
while true
144+
# make progress on the active ip set
145+
pc = _bits_findnext(ip.bits, pc´´)::Int
146+
pc > n && break
147+
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
148+
pc´ = pc + 1 # next program-counter (after executing instruction)
149+
if pc == pc´´
150+
pc´´ = pc´
151+
end
152+
delete!(ip, pc)
153+
cur_hand = handler_at[pc]
154+
@assert cur_hand != 0 "unbalanced try/catch"
155+
stmt = code[pc]
156+
if isa(stmt, GotoNode)
157+
pc´ = stmt.label
158+
elseif isa(stmt, GotoIfNot)
159+
l = stmt.dest::Int
160+
if handler_at[l] != cur_hand
161+
@assert handler_at[l] == 0 "unbalanced try/catch"
162+
handler_at[l] = cur_hand
163+
if l < pc´´
164+
pc´´ = l
165+
end
166+
push!(ip, l)
167+
end
168+
elseif isa(stmt, ReturnNode)
169+
@assert !isdefined(stmt, :val) "unbalanced try/catch"
170+
break
171+
elseif isa(stmt, Expr)
172+
head = stmt.head
173+
if head === :enter
174+
cur_hand = pc
175+
elseif head === :leave
176+
l = stmt.args[1]::Int
177+
for i = 1:l
178+
cur_hand = handler_at[cur_hand]
179+
end
180+
cur_hand == 0 && break
181+
end
182+
end
183+
184+
pc´ > n && break # can't proceed with the fast-path fall-through
185+
if handler_at[pc´] != cur_hand
186+
@assert handler_at[pc´] == 0 "unbalanced try/catch"
187+
handler_at[pc´] = cur_hand
188+
elseif !in(pc´, ip)
189+
break # already visited
190+
end
191+
pc = pc´
192+
end
193+
end
194+
195+
@assert first(ip) == n + 1
196+
return handler_at
197+
end
198+
199+
121200
"""
122201
Iterate through all callers of the given InferenceState in the abstract
123202
interpretation stack (including the given InferenceState itself), vising

test/compiler/inference.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3451,3 +3451,48 @@ end
34513451
f41908(x::Complex{T}) where {String<:T<:String} = 1
34523452
g41908() = f41908(Any[1][1])
34533453
@test only(Base.return_types(g41908, ())) <: Int
3454+
3455+
# issue #42022
3456+
let x = Tuple{Int,Any}[
3457+
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
3458+
#= 2=# (0, Expr(:enter, 18))
3459+
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
3460+
#= 4=# (2, Expr(:enter, 12))
3461+
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
3462+
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
3463+
#= 7=# (4, Expr(:leave, 2))
3464+
#= 8=# (0, Core.ReturnNode(1))
3465+
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
3466+
#=10=# (4, Expr(:leave, 1))
3467+
#=11=# (2, Core.GotoNode(16))
3468+
#=12=# (4, Expr(:leave, 1))
3469+
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
3470+
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
3471+
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
3472+
#=16=# (2, Expr(:leave, 1))
3473+
#=17=# (0, Core.GotoNode(22))
3474+
#=18=# (2, Expr(:leave, 1))
3475+
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
3476+
#=20=# (0, nothing)
3477+
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
3478+
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
3479+
]
3480+
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
3481+
@test handler_at == first.(x)
3482+
end
3483+
3484+
@test only(Base.return_types((Bool,)) do y
3485+
x = 1
3486+
try
3487+
x = 2.0
3488+
try
3489+
x = '3'
3490+
y ? (return 1) : throw()
3491+
catch ex1
3492+
rethrow()
3493+
end
3494+
catch ex2
3495+
nothing
3496+
end
3497+
return x
3498+
end) === Union{Int, Float64, Char}

0 commit comments

Comments
 (0)