Skip to content

Commit 24f047c

Browse files
authored
Allow visit_custom recursion to change the order (#179)
* Allow visit_custom recursion to change the order To implement order-changing intrinsics. * Add ddt test
1 parent 3d5bee0 commit 24f047c

File tree

5 files changed

+58
-11
lines changed

5 files changed

+58
-11
lines changed

src/codegen/forward_demand.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
165165
ssa_orders[ssa.id] = order => ssa_orders[ssa.id][2]
166166
inst = ir[ssa]
167167
stmt = inst[:inst]
168-
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
168+
recurse(@nospecialize(val), new_order=order) = forward_visit!(ir, val, new_order, ssa_orders, visit_custom!)
169169
if visit_custom!(ir, ssa, order, recurse)
170170
ssa_orders[ssa.id] = order => true
171171
return
@@ -220,10 +220,15 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
220220
visit_custom! = (@nospecialize args...)->false,
221221
transform! = (@nospecialize args...)->error())
222222
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
223-
ssa_orders = [0=>false for i = 1:length(ir.stmts)]
223+
ssa_orders = [-1=>false for i = 1:length(ir.stmts)]
224224
for (ssa, order) in to_diff
225225
forward_visit!(ir, ssa, order, ssa_orders, visit_custom!)
226226
end
227+
for (ssa, (order, custom)) in enumerate(ssa_orders)
228+
if order == -1
229+
ssa_orders[ssa] = 0 => custom
230+
end
231+
end
227232

228233
truncation_map = Dict{Pair{SSAValue, Int}, SSAValue}()
229234

@@ -266,7 +271,9 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
266271
end
267272

268273
for (ssa, (order, custom)) in enumerate(ssa_orders)
269-
if order == 0
274+
if custom
275+
transform!(ir, SSAValue(ssa), order, maparg)
276+
elseif order == 0
270277
inst = ir[SSAValue(ssa)]
271278
stmt = inst[:inst]
272279
urs = userefs(stmt)
@@ -275,9 +282,6 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
275282
end
276283
inst[:inst] = urs[]
277284
continue
278-
end
279-
if custom
280-
transform!(ir, SSAValue(ssa), order, maparg)
281285
else
282286
inst = ir[SSAValue(ssa)]
283287
stmt = inst[:inst]

src/stage1/recurse_fwd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
4747
return generate_lambda_ex(world, source,
4848
Core.svec(:ff, :args), Core.svec(), :(∂☆builtin(args)))
4949
end
50-
50+
5151
mthds = Base._methods_by_ftype(sig, -1, world)
5252
if mthds === nothing || length(mthds) != 1
5353
# Core.println("[perform_fwd_transform] ", sig, " => ", mthds)

src/stage2/forward.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
using .CC: compact!
22

3+
function is_known_invoke_or_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact})
4+
return is_known_invoke(x, func, ir) || CC.is_known_call(x, func, ir)
5+
end
6+
7+
function is_known_invoke(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact})
8+
isexpr(x, :invoke) || return false
9+
ft = argextype(x.args[2], ir)
10+
return singleton_type(ft) === func
11+
end
12+
13+
@noinline function dont_use_ddt_intrinsic(x::Float64)
14+
if Base.inferencebarrier(true)
15+
error("Intrinsic not transformed")
16+
end
17+
return Base.inferencebarrier(0.0)::Float64
18+
end
19+
320
# Engineering entry point for the 2nd-order forward AD functionality. This is
421
# unlikely to be the actual interface. For now, it is used for testing.
522
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
@@ -28,24 +45,41 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
2845
if isa(stmt, ReturnNode)
2946
recurse(stmt.val)
3047
return true
48+
elseif is_known_invoke_or_call(stmt, dont_use_ddt_intrinsic, ir)
49+
recurse(stmt.args[end], order+1)
50+
return true
3151
else
3252
return false
3353
end
3454
end
3555

36-
function transform!(ir::IRCode, ssa::SSAValue, _, _)
56+
function transform!(ir::IRCode, ssa::SSAValue, _, maparg)
3757
inst = ir[ssa]
3858
stmt = inst[:inst]
3959
if isa(stmt, ReturnNode)
60+
if order == 0
61+
return
62+
end
4063
nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order)), Any))
4164
inst[:inst] = ReturnNode(nr)
65+
elseif is_known_invoke_or_call(stmt, dont_use_ddt_intrinsic, ir)
66+
arg = maparg(stmt.args[end], ssa, order+1)
67+
if order > 0
68+
replace_call!(ir, ssa, Expr(:call, error, "Only order 0 implemented here"))
69+
else
70+
replace_call!(ir, ssa, Expr(:call, getindex, arg, TaylorTangentIndex(1)))
71+
end
4272
else
4373
error()
4474
end
4575
end
4676

47-
function transform!(ir::IRCode, arg::Argument, _, _)
48-
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
77+
function transform!(ir::IRCode, arg::Argument, order, _)
78+
if order == 0
79+
return arg
80+
else
81+
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
82+
end
4983
end
5084

5185
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!)

src/tangent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ end
283283
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val
284284

285285
"""
286-
CompositeBundle{N, B <: Tuple}
286+
CompositeBundle{N, B, B <: Tuple}
287287
288288
Represents the tagent bundle where the base space is some tuple or struct type.
289289
Mathematically, this tangent bundle is the product bundle of the individual

test/stage2_fwd.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,13 @@ module stage2_fwd
5959
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
6060
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
6161
end
62+
63+
@testset "ddt intrinsic" begin
64+
function my_cos_ddt(x)
65+
return Diffractor.dont_use_ddt_intrinsic(sin(x))
66+
end
67+
let my_cos_ddt_transformed = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(my_cos_ddt), Float64}, 0)
68+
@test my_cos_ddt_transformed(1.0) == cos(1.0)
69+
end
70+
end
6271
end

0 commit comments

Comments
 (0)