Skip to content

Commit b81255d

Browse files
authored
Try to force better forward mode inference (#111)
By adding explicit recursion functions, and reducing the amount of inference work required for higher order forward mode.
1 parent 5b3a846 commit b81255d

File tree

6 files changed

+53
-28
lines changed

6 files changed

+53
-28
lines changed

src/Diffractor.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ include("tangent.jl")
1313
include("jet.jl")
1414

1515
include("stage1/generated.jl")
16-
include("stage1/termination.jl")
1716
include("stage1/forward.jl")
1817
include("stage1/recurse_fwd.jl")
1918
include("stage1/mixed.jl")
@@ -36,4 +35,6 @@ include("higher_fwd_rules.jl")
3635

3736
include("debugutils.jl")
3837

38+
include("stage1/termination.jl")
39+
3940
end

src/stage1/forward.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,27 @@ primal(z::ZeroTangent) = ZeroTangent()
1010

1111
first_partial(x) = partial(x, 1)
1212

13-
# TODO: Which version do we want in ChainRules?
14-
function my_frule(args::ATB{1}...)
15-
frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
16-
end
17-
18-
# Fast path for some hot cases
19-
my_frule(::ZeroBundle{1, typeof(frule)}, args::ATB{1}...) = nothing
20-
my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
21-
22-
(::∂☆{N})(::ZeroBundle{N, typeof(my_frule)}, ::ZeroBundle{N, ZeroBundle{1, typeof(frule)}}, args::ATB{N}...) where {N} = ZeroBundle{N}(nothing)
23-
(::∂☆{N})(::ZeroBundle{N, typeof(my_frule)}, ::ZeroBundle{N, ZeroBundle{1, typeof(my_frule)}}, args::ATB{N}...) where {N} = ZeroBundle{N}(nothing)
24-
2513
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
26-
UniformBundle{minus1(N), <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)
14+
UniformBundle{minus1(N), <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val),
15+
UniformBundle{1, U}(b.tangent.val, b.tangent.val))
2716

2817
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
2918
# N.B: This depends on the special properties of the canonical tangent index order
19+
Base.@constprop :aggressive function _sdown(i::Int64)
20+
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
21+
end
3022
ExplicitTangentBundle{N-1}(
3123
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
32-
ntuple(2^(N-1)-1) do i
33-
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
34-
end)
24+
ntuple(_sdown, 2^(N-1)-1))
3525
end
3626

3727
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
28+
Base.@constprop :aggressive function _sdown(i::Int64)
29+
ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
30+
end
3831
TaylorBundle{N-1}(
3932
ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)),
40-
ntuple(N-1) do i
41-
ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
42-
end)
33+
ntuple(_sdown, N-1))
4334
end
4435

4536
function shuffle_down(b::CompositeBundle{N, B}) where {N, B}
@@ -106,10 +97,17 @@ end
10697
struct ∂☆internal{N}; end
10798
struct ∂☆shuffle{N}; end
10899

109-
shuffle_base(r) = TaylorBundle{1}(r[1], (r[2],))
100+
function shuffle_base(r)
101+
(primal, dual) = r
102+
if isa(dual, Union{NoTangent, ZeroTangent})
103+
UniformBundle{1}(primal, dual)
104+
else
105+
TaylorBundle{1}(primal, (dual,))
106+
end
107+
end
110108

111109
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
112-
r = my_frule(args...)
110+
r = frule(#=DiffractorRuleConfig(),=# map(first_partial, args), map(primal, args)...)
113111
if r === nothing
114112
return ∂☆recurse{1}()(args...)
115113
else
@@ -125,7 +123,9 @@ end
125123

126124
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
127125
∂☆p = ∂☆{minus1(N)}()
128-
∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...)
126+
downargs = map(shuffle_down, args)
127+
tupargs = ∂vararg{minus1(N)}()(map(first_partial, downargs)...)
128+
∂☆p(ZeroBundle{minus1(N)}(frule), #= ZeroBundle{minus1(N)}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
129129
end
130130

131131
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}

src/stage1/recurse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ function transform!(ci, meth, nargs, sparams, N)
261261
cfg = compute_basic_blocks(code)
262262
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
263263
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
264-
slottypes = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
264+
slottypes = ci.slottypes === nothing ? nothing : UInt8[(Any for i = 1:2)..., ci.slottypes...]
265265

266266
meta = Expr[]
267267
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],

src/stage1/recurse_fwd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospe
5454
ci′.method_for_inference_limit_heuristics = match.method
5555
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
5656
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
57-
slottypes = Any[(Any for i = 1:2)..., ci.slotflags...]
57+
slottypes = ci.slottypes === nothing ? nothing : Any[(Any for i = 1:2)..., ci.slottypes...]
5858
ci′.slotnames = slotnames
5959
ci′.slotflags = slotflags
6060
ci′.slottypes = slottypes

src/stage1/termination.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,30 @@ which(Tuple{∂⃖{N}, ∂⃖{1}, Vararg{Any}} where {N}).recursion_relation = f
4444
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
4545
end
4646

47+
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, typemax(UInt64))
48+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
49+
# Recursion from a higher to a lower order is always allowed
50+
parent_order = parent_sig.parameters[1].parameters[1]
51+
child_order = new_sig.parameters[1].parameters[1]
52+
#@Core.Main.Base.show (parent_order, child_order)
53+
if parent_order > child_order
54+
return true
55+
end
56+
@show (parent_sig, new_sig)
57+
return false
58+
end
59+
end
60+
61+
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆internal{N}, Vararg{Any}} where {N}, nothing, -1, typemax(UInt64))
62+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
63+
return true
64+
end
65+
end
66+
67+
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆{N}, Vararg{Any}} where {N}, nothing, -1, typemax(UInt64))
68+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
69+
return true
70+
end
71+
end
72+
4773
end

test/stage2_fwd.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ module stage2_fwd
1414

1515
self_minus(a) = myminus(a, a)
1616
let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2)
17-
# TODO: The IR for this currently contains Union{Diffractor.TangentBundle{2, Float64, Diffractor.ExplicitTangent{Tuple{Float64, Float64, Float64}}}, Diffractor.TangentBundle{2, Float64, Diffractor.TaylorTangent{Tuple{Float64, Float64}}}}
18-
# We should have Diffractor be able to prove uniformity
19-
@test_broken isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
17+
@test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
2018
@test self_minus′′(1.0) == 0.
2119
end
2220
end

0 commit comments

Comments
 (0)