Skip to content

Commit 022865e

Browse files
authored
Don't insert truncation within phi blocks (#162)
* Don't insert truncation within phi blocks * more efficient search * clearer test * change test to check all ading phi nodes * Start search at start
1 parent 10cef62 commit 022865e

File tree

3 files changed

+68
-14
lines changed

3 files changed

+68
-14
lines changed

src/codegen/forward_demand.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
240240
if argorder != order
241241
@assert order < argorder
242242
return get!(truncation_map, arg=>order) do
243+
# identify where to insert. Must be after phi blocks
244+
pos = SSAValue(find_end_of_phi_block(ir, arg.id))
243245
if order == 0
244-
insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
246+
insert_node!(ir, pos, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
245247
else
246-
insert_node!(ir, arg, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true)
248+
insert_node!(ir, pos, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true)
247249
end
248250
end
249251
end

src/stage1/compiler_utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,29 @@ Base.lastindex(x::Core.Compiler.InstructionStream) =
6464
if isdefined(Core.Compiler, :CallInfo)
6565
Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo()
6666
end
67+
68+
69+
"""
70+
find_end_of_phi_block(ir, start_search_idx)
71+
72+
Finds the last index within the same basic block, on or after the `start_search_idx` which is not within a phi block.
73+
A phi-block is a run on PhiNodes or nothings that must be the first statements within the basic block.
74+
75+
If `start_search_idx` is not within a phi block to begin with, then just returns `start_search_idx`
76+
"""
77+
function find_end_of_phi_block(ir, start_search_idx::Int)
78+
# Short-cut for early exit:
79+
stmt = ir.stmts[start_search_idx][:inst]
80+
stmt !== nothing && !isa(stmt, PhiNode) && return start_search_idx
81+
82+
# Actually going to have to go digging throught the IR to out if were are in a phi block
83+
# TODO: this is not so efficient. maybe preconstruct CFG then use block_for_inst?
84+
bb=CC.block_for_inst(ir.cfg, start_search_idx)
85+
end_search_idx=ir.cfg.blocks[bb].stmts[end]
86+
for idx in (start_search_idx):(end_search_idx-1)
87+
stmt = ir.stmts[idx+1][:inst]
88+
# next statment is no longer in a phi block, so safe to insert
89+
stmt !== nothing && !isa(stmt, PhiNode) && return idx
90+
end
91+
return end_search_idx
92+
end

test/stage2_fwd.jl

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,23 @@ module stage2_fwd
4343
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
4444
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
4545
end
46+
end
47+
4648

49+
module forward_diff_no_inf # todo: move this to a seperate file
50+
using Diffractor, Test
51+
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
52+
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
53+
function identity_transform!(ir, arg::Core.Argument, order)
54+
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
55+
end
56+
4757
@testset "Constructors in forward_diff_no_inf!" begin
4858
struct Bar148
4959
v
5060
end
5161
foo_148(x) = Bar148(x)
5262

53-
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
54-
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
55-
function identity_transform!(ir, arg::Core.Argument, order)
56-
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
57-
end
58-
5963
ir = first(only(Base.code_ircode(foo_148, Tuple{Float64})))
6064
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
6165
ir2 = Core.Compiler.compact!(ir)
@@ -67,17 +71,39 @@ module stage2_fwd
6771
@eval global _coeff::Float64=24.5
6872
plus_a_global(x) = x + _coeff
6973

70-
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
71-
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
72-
function identity_transform!(ir, arg::Core.Argument, order)
73-
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
74-
end
75-
7674
ir = first(only(Base.code_ircode(plus_a_global, Tuple{Float64})))
7775
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
7876
ir2 = Core.Compiler.compact!(ir)
7977
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly
8078
f = Core.OpaqueClosure(ir2; do_compile=false)
8179
@test f(3.5) == 28.0
8280
end
81+
82+
@testset "runs of phi nodes" begin
83+
function phi_run(x::Float64)
84+
a = 2.0
85+
b = 2.0
86+
if (@noinline rand()) < 0 # this branch will never actually be taken
87+
a = -100.0
88+
b = 200.0
89+
end
90+
return x - a + b
91+
end
92+
93+
input_ir = first(only(Base.code_ircode(phi_run, Tuple{Float64})))
94+
ir = copy(input_ir)
95+
#Workout where to diff to trigger error
96+
diff_ssa = Core.SSAValue[]
97+
for idx in 1:length(ir.stmts)
98+
if ir.stmts[idx][:inst] isa Core.PhiNode
99+
push!(diff_ssa, Core.SSAValue(idx))
100+
end
101+
end
102+
103+
Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!)
104+
ir2 = Core.Compiler.compact!(ir)
105+
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
106+
f = Core.OpaqueClosure(ir2; do_compile=false)
107+
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
108+
end
83109
end

0 commit comments

Comments
 (0)