Skip to content

Commit e308c82

Browse files
authored
Merge pull request #180 from JuliaDiff/sf/replace_call
Use `replace_call!()` to replace `Expr(:call, ...)` values
2 parents dc3c60e + 077de69 commit e308c82

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/codegen/forward_demand.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I
151151
frule_result = insert_node!(ir, ssa, NewInstruction(
152152
frule_call, frule_rt, info.frule_call.info, inst[:line],
153153
frule_flag))
154-
ir[ssa][:inst] = Expr(:call, GlobalRef(Core, :getfield), frule_result, 1)
154+
replace_call!(ir, ssa, Expr(:call, GlobalRef(Core, :getfield), frule_result, 1))
155155
Δssa = insert_node!(ir, ssa, NewInstruction(
156156
Expr(:call, GlobalRef(Core, :getfield), frule_result, 2), CC.getfield_tfunc(CC.typeinf_lattice(interp), frule_rt, Const(2))), #=attach_after=#true)
157157
return Δssa
@@ -285,15 +285,13 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
285285
newargs = map(stmt.args[2:end]) do @nospecialize arg
286286
maparg(arg, SSAValue(ssa), order)
287287
end
288-
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
289-
inst[:type] = Any
288+
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...))
290289
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
291290
newargs = map(stmt.args) do @nospecialize arg
292291
maparg(arg, SSAValue(ssa), order)
293292
end
294293
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
295-
inst[:inst] = Expr(:call, f, newargs...)
296-
inst[:type] = Any
294+
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
297295
elseif isa(stmt, PiNode)
298296
# TODO: New PiNode that discriminates based on primal?
299297
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
@@ -304,8 +302,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
304302
stmt = insert_node!(ir, ssa, NewInstruction(inst))
305303
end
306304

307-
inst[:inst] = Expr(:call, ZeroBundle{order}, stmt)
308-
inst[:type] = Any
305+
replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt))
309306
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
310307
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
311308
inst[:type] = Any

src/stage1/compiler_utils.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,10 @@ function find_end_of_phi_block(ir, start_search_idx::Int)
8989
stmt !== nothing && !isa(stmt, PhiNode) && return idx
9090
end
9191
return end_search_idx
92-
end
92+
end
93+
94+
function replace_call!(ir, idx::SSAValue, new_call)
95+
ir[idx][:inst] = new_call
96+
ir[idx][:type] = Any
97+
ir[idx][:info] = CC.NoCallInfo()
98+
end

0 commit comments

Comments
 (0)