Skip to content

Commit dc3c60e

Browse files
authored
Pass maparg to transform (#178)
In case the custom transform needs to get its argument of a particular order.
1 parent ea2e9f5 commit dc3c60e

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

src/codegen/forward_demand.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
254254
return arg
255255
elseif isa(arg, Argument)
256256
# TODO: Should we remember whether the callbacks wanted the arg?
257-
return transform!(ir, arg, order)
257+
return transform!(ir, arg, order, maparg)
258258
elseif isa(arg, GlobalRef)
259259
@assert isconst(arg)
260260
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
@@ -277,7 +277,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
277277
continue
278278
end
279279
if custom
280-
transform!(ir, SSAValue(ssa), order)
280+
transform!(ir, SSAValue(ssa), order, maparg)
281281
else
282282
inst = ir[SSAValue(ssa)]
283283
stmt = inst[:inst]

src/stage2/forward.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
3333
end
3434
end
3535

36-
function transform!(ir::IRCode, ssa::SSAValue, _)
36+
function transform!(ir::IRCode, ssa::SSAValue, _, _)
3737
inst = ir[ssa]
3838
stmt = inst[:inst]
3939
if isa(stmt, ReturnNode)
@@ -44,7 +44,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
4444
end
4545
end
4646

47-
function transform!(ir::IRCode, arg::Argument, _)
47+
function transform!(ir::IRCode, arg::Argument, _, _)
4848
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
4949
end
5050

test/forward_diff_no_inf.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
module forward_diff_no_inf
33
using Diffractor, Test
44
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
5-
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
6-
function identity_transform!(ir, arg::Core.Argument, order)
5+
identity_transform!(ir, ssa::Core.SSAValue, order, _) = ir[ssa]
6+
function identity_transform!(ir, arg::Core.Argument, order, _)
77
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
88
end
9-
9+
1010
@testset "Constructors in forward_diff_no_inf!" begin
1111
struct Bar148
1212
v
@@ -47,7 +47,7 @@ module forward_diff_no_inf
4747
end
4848
return x - a + b
4949
end
50-
50+
5151
input_ir = first(only(Base.code_ircode(phi_run, Tuple{Float64})))
5252
ir = copy(input_ir)
5353
#Workout where to diff to trigger error
@@ -57,11 +57,11 @@ module forward_diff_no_inf
5757
push!(diff_ssa, Core.SSAValue(idx))
5858
end
5959
end
60-
60+
6161
Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!)
6262
ir2 = Core.Compiler.compact!(ir)
6363
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
6464
f = Core.OpaqueClosure(ir2; do_compile=false)
6565
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
66-
end
66+
end
6767
end

0 commit comments

Comments
 (0)