Skip to content

Commit 05810c2

Browse files
oxinaboxaviatesk
andauthored
When reinference is required use IR_FLAG_REFINED (#164)
Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent 64a740e commit 05810c2

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

src/codegen/forward_demand.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
248248
# identify where to insert. Must be after phi blocks
249249
pos = SSAValue(find_end_of_phi_block(ir, arg.id))
250250
if order == 0
251-
insert_node!(ir, pos, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
251+
insert_node!(ir, pos, NewInstruction(Expr(:call, primal, arg)), #=attach_after=#true)
252252
else
253-
insert_node!(ir, pos, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true)
253+
insert_node!(ir, pos, NewInstruction(Expr(:call, truncate, arg, Val{order}())), #=attach_after=#true)
254254
end
255255
end
256256
end
@@ -262,7 +262,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
262262
return transform!(ir, arg, order, maparg)
263263
elseif isa(arg, GlobalRef)
264264
@assert isconst(arg)
265-
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
265+
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg)))
266266
elseif isa(arg, QuoteNode)
267267
return ZeroBundle{order}(arg.value)
268268
end
@@ -300,6 +300,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
300300
# TODO: New PiNode that discriminates based on primal?
301301
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
302302
inst[:type] = Any
303+
inst[:flag] |= CC.IR_FLAG_REFINED
303304
elseif isa(stmt, GlobalRef)
304305
if !isconst(stmt)
305306
# Non-const GlobalRefs need to need to be accessed as seperate statements
@@ -310,6 +311,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
310311
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
311312
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
312313
inst[:type] = Any
314+
inst[:flag] |= CC.IR_FLAG_REFINED
313315
elseif isa(stmt, Expr) || isa(stmt, PhiNode) || isa(stmt, PhiCNode) ||
314316
isa(stmt, UpsilonNode) || isa(stmt, GotoIfNot) || isa(stmt, Argument)
315317
urs = userefs(stmt)
@@ -318,6 +320,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
318320
end
319321
inst[:inst] = urs[]
320322
inst[:type] = Any
323+
inst[:flag] |= CC.IR_FLAG_REFINED
321324
else
322325
val = ZeroBundle{order}(inst[:inst])
323326
inst[:inst] = val
@@ -336,10 +339,12 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
336339
ir = compact!(ir)
337340

338341
for i = 1:length(ir.stmts)
339-
if ir[SSAValue(i)][:type] == Any
340-
# TODO: this flag should actually be being set at the insert site
341-
# and we should be filtering on if it is present rather than [:type]=Any
342-
ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED
342+
inst = ir[SSAValue(i)][:inst]
343+
if !isa(inst, ReturnNode) && ir[SSAValue(i)][:type] === Any
344+
if iszero(ir[SSAValue(i)][:flag] & CC.IR_FLAG_REFINED)
345+
@warn "IR_FLAG_REFINED Flag missed on statement" i inst
346+
ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED
347+
end
343348
end
344349
end
345350

src/stage1/compiler_utils.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Utilities that should probably go into Core.Compiler
2-
using Core.Compiler: CFG, BasicBlock, BBIdxIter
2+
using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter
33

44
function Base.push!(cfg::CFG, bb::BasicBlock)
55
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
@@ -12,25 +12,22 @@ Base.getindex(ir::IRCode, ssa::SSAValue) =
1212

1313
Base.copy(ir::IRCode) = Core.Compiler.copy(ir)
1414

15-
function Core.Compiler.NewInstruction(node)
16-
Core.Compiler.NewInstruction(node, Any)
17-
end
15+
Core.Compiler.NewInstruction(@nospecialize node) =
16+
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)
1817

19-
function Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v)
18+
Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) =
2019
Core.Compiler.setindex!(x, v, f)
21-
end
2220

23-
function Base.getproperty(x::Core.Compiler.Instruction, f::Symbol)
21+
Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) =
2422
Core.Compiler.getindex(x, f)
25-
end
2623

27-
function Base.setindex!(ir::Core.Compiler.IRCode, ni::NewInstruction, i::Int)
24+
function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int)
2825
stmt = ir.stmts[i]
2926
stmt.inst = ni.stmt
3027
stmt.type = ni.type
3128
stmt.flag = something(ni.flag, 0) # fixes 1.9?
3229
stmt.line = something(ni.line, 0)
33-
ni
30+
return ni
3431
end
3532

3633
function Base.push!(ir::IRCode, ni::NewInstruction)
@@ -67,14 +64,14 @@ end
6764

6865

6966
"""
70-
find_end_of_phi_block(ir, start_search_idx)
67+
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
7168
7269
Finds the last index within the same basic block, on or after the `start_search_idx` which is not within a phi block.
7370
A phi-block is a run on PhiNodes or nothings that must be the first statements within the basic block.
7471
7572
If `start_search_idx` is not within a phi block to begin with, then just returns `start_search_idx`
7673
"""
77-
function find_end_of_phi_block(ir, start_search_idx::Int)
74+
function find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
7875
# Short-cut for early exit:
7976
stmt = ir.stmts[start_search_idx][:inst]
8077
stmt !== nothing && !isa(stmt, PhiNode) && return start_search_idx
@@ -90,8 +87,9 @@ function find_end_of_phi_block(ir, start_search_idx::Int)
9087
return end_search_idx
9188
end
9289

93-
function replace_call!(ir, idx::SSAValue, new_call)
90+
function replace_call!(ir::IRCode, idx::SSAValue, new_call::Expr)
9491
ir[idx][:inst] = new_call
9592
ir[idx][:type] = Any
9693
ir[idx][:info] = CC.NoCallInfo()
94+
ir[idx][:flag] = CC.IR_FLAG_REFINED
9795
end

src/stage2/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
6060
if order == 0
6161
return
6262
end
63-
nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order)), Any))
63+
nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order))))
6464
inst[:inst] = ReturnNode(nr)
6565
elseif is_known_invoke_or_call(stmt, dont_use_ddt_intrinsic, ir)
6666
arg = maparg(stmt.args[end], ssa, order+1)

0 commit comments

Comments
 (0)