Skip to content

Commit c0c77c8

Browse files
committed
WIP Eras mode
1 parent 5499312 commit c0c77c8

File tree

6 files changed

+130
-72
lines changed

6 files changed

+130
-72
lines changed

src/interface.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,28 @@ const ∂⃖¹ = ∂⃖{1}()
2525
(::Type{∂⃖})(args...) = ∂⃖¹(args...)
2626

2727
"""
28-
∂☆{N}
28+
∂☆{N,E}
2929
30-
∂☆{N} is the forward-mode AD functor of order `N`. A call
30+
∂☆{N} is the forward-mode AD functor of order `N` (an integer). A call
3131
`(::∂☆{N})(f, args...)` evaluating a function `f: A -> B` is lifted to its
3232
pushforward on the N-th order tangent bundle `f⋆: Tⁿ A -> Tⁿ B`.
33+
34+
35+
!!! advanced "Eras Mode"
36+
E (a bool, default false) is for Eras mode. In Eras mode, we are Taylor or bust.
37+
Normally if a particular derivative can not be represented as a `TaylorBundle`
38+
we fall back and represent it as an `ExplicitTangentBundle`.
39+
However, in Eras mode we error if it can't be represented as a TaylorBundle.
40+
In general, this is not wanted since it often will break nested AD.
41+
But in the cases where it doesn't, it's really fast, since it means we can rewrite nested AD
42+
as Taylor-mode AD (plus it's more type stable).
43+
To be safe in Eras mode, it is sufficient, but not necessary, to be doing nested AD with
44+
respect to the same variable. It also works in other cases where (likely by problem construction)
45+
ADing with respect to a second variable happens to result in something that can be represented
46+
with a `TaylorBundle` also. (You need your different partials to happen to be exactly equal).
3347
"""
34-
struct ∂☆{N}; end
48+
struct ∂☆{N, E}; end
49+
∂☆{N}() where N = ∂☆{N,false}() # default to not using Era mode
3550
const ∂☆¹ = ∂☆{1}()
3651

3752
"""

src/stage1/broadcast.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@ using Base.Broadcast
22
using Base.Broadcast: broadcasted, Broadcasted
33

44
# Forward mode broadcast rule
5-
struct FwdBroadcast{N, T<:AbstractTangentBundle{N}}
5+
struct FwdBroadcast{N, E, T<:AbstractTangentBundle{N}}
66
f::T
77
end
8-
(f::FwdBroadcast{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
8+
# Convenience constructor: the caller supplies only the Eras-mode flag `E`;
# the order `N` and bundle type `T` are taken from the wrapped bundle `f`.
FwdBroadcast{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}} = FwdBroadcast{N,E,T}(f)
9+
10+
(f::FwdBroadcast{N,E})(args::AbstractTangentBundle{N}...) where {N,E} = ∂☆{N,E}()(f.f, args...)
911

1012
n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))
1113

12-
function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)},
13-
bc::ATB{N, <:Broadcasted}) where {N}
14+
function (∂ₙ::∂☆{N,E})(zc::AbstractZeroBundle{N, typeof(copy)},
15+
bc::ATB{N, <:Broadcasted}) where {N,E}
1416
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
1517
args = n_getfield(∂ₙ, bc, :args)
1618
r = copy(Broadcasted(
17-
FwdMap(n_getfield(∂ₙ, bc, :f)),
19+
FwdMap{E}(n_getfield(∂ₙ, bc, :f)),
1820
ntuple(length(primal(args))) do i
1921
val = n_getfield(∂ₙ, args, i)
2022
if ndims(primal(val)) == 0

src/stage1/forward.jl

Lines changed: 94 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -33,38 +33,68 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
3333
ntuple(_sdown, N-1))
3434
end
3535

36-
@noinline check_taylor(z₁, z₂) = @assert(z₁ == z₂, "$z₁ == $z₂")
36+
"""
    TaylorRequired(order, z₁, z₂)

Error thrown by `shuffle_up` in Eras mode (`taylor_or_bust == true`) when a
result cannot be represented as a `TaylorBundle`.

- `order`: the order reported by `find_taylor_incompatibility` at which the
  Taylor requirement first failed (`0` refers to the primal).
- `z₁`, `z₂`: the two values that were required to be equal but were not.
"""
struct TaylorRequired <: Exception
    order
    z₁
    z₂
end
41+
# Pretty-print a `TaylorRequired` error: one explanatory line, followed by the two
# values that were expected to be equal, labelled with the failing order.
# The method is restricted to `TaylorRequired` so that it does not replace Base's
# generic `showerror` fallback for every other error type.
function Base.showerror(io::IO, err::TaylorRequired)
    order_str1 = order_str(err.order)
    # `println` (not `print`) so the values below start on their own lines
    println(io, "In Eras mode all higher order derivatives must be taylor, but encountered one where the taylor requirement z₁ == z₂ was not met.")
    println(io, "derivative on $order_str1 path: z₁ = ", err.z₁)
    println(io, "$order_str1 on the derivative path: z₂ = ", err.z₂)
end
47+
48+
"""
    order_str(order::Integer)

Return a human-readable name for a derivative order: `"primal"` for 0,
`"derivative"` for 1, and `"<ordinal> derivative"` (e.g. `"2nd derivative"`,
`"21st derivative"`) for higher orders. `order` must be non-negative.
"""
function order_str(order::Integer)
    @assert order>=0
    order == 0 && return "primal"
    order == 1 && return "derivative"
    # English ordinals: 11, 12 and 13 (and 111, 112, ...) take "th";
    # otherwise the suffix is determined by the last digit.
    suffix = if order % 100 in 11:13
        "th"
    elseif order % 10 == 1
        "st"
    elseif order % 10 == 2
        "nd"
    elseif order % 10 == 3
        "rd"
    else
        "th"
    end
    return "$(order)$(suffix) derivative"
end
3762

38-
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
63+
"finds the lowerest order derivative that is not taylor compatible, or returns -1 if all compatible"
64+
@noinline function find_taylor_incompatibility(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
65+
partial(r, 1)[1] == primal(r)[2] || return 0
66+
for i in 1:(N-1)
67+
partial(r, i+1)[1] == partial(r, i)[2] || return i
68+
end
69+
return -1 # all compatible
70+
end
71+
72+
# Return the pair of values `(z₁, z₂)` whose inequality was detected at
# `fail_order`, using the same indexing as `find_taylor_incompatibility`:
# order 0 compares `partial(r, 1)[1]` with `primal(r)[2]`; order `i >= 1` compares
# `partial(r, i+1)[1]` with `partial(r, i)[2]`.
# The tuple parameter is `<:Tuple{Any,Any}` because type parameters are invariant:
# a bare `Tuple{Any,Any}` would not match a bundle of a concrete 2-tuple type.
function taylor_failure_values(r::TaylorBundle{<:Any, <:Tuple{Any,Any}}, fail_order)
    fail_order == 0 && return partial(r,1)[1], primal(r)[2]
    return partial(r, fail_order+1)[1], partial(r, fail_order)[2]
end
76+
77+
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}, ::Val{taylor_or_bust}) where {B1,B2, taylor_or_bust}
3978
z₀ = primal(r)[1]
4079
z₁ = partial(r, 1)[1]
4180
z₂ = primal(r)[2]
4281
z₁₂ = partial(r, 1)[2]
43-
if true
44-
check_taylor(z₁, z₂)
82+
83+
taylor_fail_order = find_taylor_incompatibility(r)
84+
if taylor_fail_order < 0
4585
return TaylorBundle{2}(z₀, (z₁, z₁₂))
86+
elseif taylor_or_bust
87+
@assert taylor_fail_order == 0 # can't be higher
88+
throw(TaylorRequired(taylor_fail_order, z₁, z₂))
4689
else
4790
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
4891
end
4992
end
5093

51-
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
52-
primal(b) === a[TaylorTangentIndex(1)] || return false
53-
return all(1:(N-1)) do i
54-
b[TaylorTangentIndex(i)] === a[TaylorTangentIndex(i+1)]
55-
end
56-
end
57-
58-
function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
59-
partial(r, 1)[1] == primal(r)[2] || return false
60-
return all(1:N-1) do i
61-
partial(r, i+1)[1] == partial(r, i)[2]
62-
end
63-
end
64-
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
94+
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}, ::Val{taylor_or_bust}) where {N, B1,B2, taylor_or_bust}
6595
the_primal = primal(r)[1]
66-
if true
67-
@assert taylor_compatible(r)
96+
taylor_fail_order = find_taylor_incompatibility(r)
97+
# `taylor_fail_order` is the integer returned by `find_taylor_incompatibility(r)`;
# compare it directly (calling it as a function would throw a MethodError).
if taylor_fail_order < 0
6898
the_partials = ntuple(N+1) do i
6999
if i <= N
70100
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
@@ -73,6 +103,9 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
73103
end
74104
end
75105
return TaylorBundle{N+1}(the_primal, the_partials)
106+
elseif taylor_or_bust
107+
@assert taylor_fail_order < N
108+
throw(TaylorRequired(taylor_fail_order, taylor_failure_values(r, taylor_fail_order)...))
76109
else
77110
#XXX: am dubious of the correctness of this
78111
a_partials = ntuple(i->partial(r, i)[1], N)
@@ -83,7 +116,7 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
83116
end
84117

85118

86-
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
119+
function shuffle_up(r::UniformBundle{N, B, U}, _::Val) where {N, B, U}
87120
(a, b) = primal(r)
88121
if r.tangent.val === b
89122
u = b
@@ -94,7 +127,7 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
94127
end
95128
UniformBundle{N+1}(a, u)
96129
end
97-
@ChainRulesCore.non_differentiable shuffle_up(r::UniformBundle)
130+
@ChainRulesCore.non_differentiable shuffle_up(r::UniformBundle, ::Val)
98131

99132

100133
function shuffle_up_bundle(r::Diffractor.TangentBundle{1, B}) where {B<:ATB{1}}
@@ -124,9 +157,6 @@ function shuffle_up_bundle(r::UniformBundle{1, <:UniformBundle{1, B, U}}) where
124157
return UniformBundle{2, B, U}(primal(primal(r)))
125158
end
126159

127-
function shuffle_down_bundle(b::ExplicitTangentBundle{N, B}) where {N, B}
128-
error("TODO")
129-
end
130160

131161
function shuffle_down_bundle(b::TaylorBundle{2, B}) where {B}
132162
z₀ = primal(b)
@@ -135,8 +165,10 @@ function shuffle_down_bundle(b::TaylorBundle{2, B}) where {B}
135165
TaylorBundle{1}(TaylorBundle{1}(z₀, (z₁,)), (TaylorBundle{1}(z₁, (z₁₂,)),))
136166
end
137167

138-
struct ∂☆internal{N}; end
139-
struct ∂☆recurse{N}; end
168+
#N order, this should be a positive Int
169+
#E eras mode, this controls if we should Error if it isn't Taylor. This should be a Bool
170+
struct ∂☆internal{N, E}; end
171+
struct ∂☆recurse{N, E}; end
140172
struct ∂☆shuffle{N}; end
141173

142174
function shuffle_base(r)
@@ -151,26 +183,28 @@ function shuffle_base(r)
151183
end
152184
end
153185

154-
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
155-
r = _frule(map(first_partial, args), map(primal, args)...)
186+
function (::∂☆internal{1, E})(args::AbstractTangentBundle{1}...) where E
187+
r = _frule(Val{E}(), map(first_partial, args), map(primal, args)...)
156188
if r === nothing
157-
return ∂☆recurse{1}()(args...)
189+
return ∂☆recurse{1, E}()(args...)
158190
else
159191
return shuffle_base(r)
160192
end
161193
end
162194

163-
_frule(partials, primals...) = frule(#== DiffractorRuleConfig(), ==# partials, primals...)
164-
function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
195+
# TODO: work out why enabling calling back into AD in Eras mode causes type instability
196+
_frule(::Val{true}, partials, primals...) = frule(partials, primals...)
197+
_frule(::Val{false}, partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...)
198+
function _frule(::Any, ::NTuple{<:Any, AbstractZero}, f, primal_args...)
165199
# frules are linear in partials, so zero maps to zero, no need to evaluate the frule
166200
# If all partials are immutable AbstractZero subtypes we know we don't have to worry about a mutating frule either
167201
r = f(primal_args...)
168202
return r, zero_tangent(r)
169203
end
170204

171205
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
172-
bundles = map(bundle, args, partials)
173-
result = ∂☆internal{1}()(bundles...)
206+
bundles = map(bundle, partials, args)
207+
result = ∂☆internal{1,false}()(bundles...)
174208
primal(result), first_partial(result)
175209
end
176210

@@ -194,21 +228,21 @@ function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundl
194228
return zero_bundle{1}()(f_v(args_v...))
195229
end
196230

197-
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
231+
function (::∂☆internal{N, E})(args::AbstractTangentBundle{N}...) where {N, E}
198232
r = ∂☆shuffle{N}()(args...)
199233
if primal(r) === nothing
200-
return ∂☆recurse{N}()(args...)
234+
return ∂☆recurse{N, E}()(args...)
201235
else
202-
return shuffle_up(r)
236+
return shuffle_up(r, Val{E}())
203237
end
204238
end
205239

206240
# TODO: Generalize to N,M
207-
@inline function (::∂☆{1})(rec::AbstractZeroBundle{1, ∂☆recurse{1}}, args::ATB{1}...)
208-
return shuffle_down_bundle(∂☆recurse{2}()(map(shuffle_up_bundle, args)...))
241+
@inline function (::∂☆{1,E})(rec::AbstractZeroBundle{1, ∂☆recurse{1, E}}, args::ATB{1}...) where E
242+
return shuffle_down_bundle(∂☆recurse{2,E}()(map(shuffle_up_bundle, args)...))
209243
end
210244

211-
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
245+
# Generic entry point: forward any call on order-`N` bundles to `∂☆internal`,
# carrying the Eras-mode flag `E` along. `E` must be bound in the `where` clause,
# otherwise it would be looked up as a global when the method is defined.
(::∂☆{N,E})(args::AbstractTangentBundle{N}...) where {N,E} = ∂☆internal{N,E}()(args...)
212246

213247
# Special case rules for performance
214248
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
@@ -252,22 +286,23 @@ function (::∂☆{N})(f::ATB{N, typeof(tuple)}, args::AbstractZeroBundle{N}...)
252286
ZeroBundle{N}(map(primal, args)) # special fast case
253287
end
254288

255-
struct FwdMap{N, T<:AbstractTangentBundle{N}}
289+
struct FwdMap{N, E, T<:AbstractTangentBundle{N}}
256290
f::T
257291
end
258-
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
292+
FwdMap{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}} = FwdMap{N,E,T}(f)
293+
(f::FwdMap{N,E})(args::AbstractTangentBundle{N}...) where {N,E} = ∂☆{N,E}()(f.f, args...)
259294

260-
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
261-
∂vararg{N}()(map(FwdMap(f), destructure(tup))...)
295+
function (::∂☆{N,E})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N,E}
296+
∂vararg{N}()(map(FwdMap{E}(f), destructure(tup))...)
262297
end
263298

264-
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
299+
function (::∂☆{N,E})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N,E}
265300
# TODO: This could do an inplace map! to avoid the extra rebundling
266-
rebundle(map(FwdMap(f), map(unbundle, args)...))
301+
rebundle(map(FwdMap{E}(f), map(unbundle, args)...))
267302
end
268303

269-
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
270-
∂☆recurse{N}()(ZeroBundle{N, typeof(map)}(map), f, args...)
304+
function (::∂☆{N,E})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N, E}
305+
∂☆recurse{N,E}()(ZeroBundle{N, typeof(map)}(map), f, args...)
271306
end
272307

273308

@@ -279,29 +314,29 @@ function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N
279314
Core.ifelse(arg.primal, args...)
280315
end
281316

282-
struct FwdIterate{N, T<:AbstractTangentBundle{N}}
317+
struct FwdIterate{N, E, T<:AbstractTangentBundle{N}}
283318
f::T
284319
end
285-
function (f::FwdIterate)(arg::ATB{N}) where {N}
286-
r = ∂☆{N}()(f.f, arg)
320+
FwdIterate{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}} = FwdIterate{N,E,T}(f)
321+
function (f::FwdIterate{N,E})(arg::ATB{N}) where {N,E}
322+
r = ∂☆{N,E}()(f.f, arg)
287323
# `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
288324
isa(r, ATB{N, Nothing}) && return nothing
289-
(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
290-
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
325+
(∂☆{N,E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
326+
primal(∂☆{N,E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
291327
end
292-
@Base.constprop :aggressive function (f::FwdIterate)(arg::ATB{N}, st) where {N}
293-
r = ∂☆{N}()(f.f, arg, ZeroBundle{N}(st))
328+
# Stateful iteration step: differentiate `f.f(arg, st)` at order `N` and split the
# result into `(bundled value, plain next state)`, or return `nothing` when the
# iteration is finished. Both `getindex` calls use `∂☆{N,E}` so that the Eras-mode
# flag `E` is kept for the whole step, matching the one-argument method.
@Base.constprop :aggressive function (f::FwdIterate{N,E})(arg::ATB{N}, st) where {N,E}
    r = ∂☆{N,E}()(f.f, arg, ZeroBundle{N}(st))
    # `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
    isa(r, ATB{N, Nothing}) && return nothing
    (∂☆{N,E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
     primal(∂☆{N,E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
end
299335

300-
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
301-
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
336+
function (this::∂☆{N,E})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N,E}
337+
Core._apply_iterate(FwdIterate{E}(iterate), this, (f,), args...)
302338
end
303339

304-
305340
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
306341
r = iterate(destructure(t))
307342
r === nothing && return ZeroBundle{N}(nothing)

src/stage1/mixed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ function shuffle_down_frule(∂☆p, my_frule, args...)
4343
∂☆p(my_frule, map(shuffle_down, args)...)
4444
end
4545

46-
function (this::∂⃖{N})(::∂☆internal{M}, args::AbstractTangentBundle{1}...) where {N, M}
46+
# Reverse-mode rule for calls to `∂☆internal` on first-order bundles (mixed mode).
# `E` is bound through `∂☆internal{M, E}` so that it is not an unbound static
# parameter, and is passed on to `∂☆recurse`, which has no `∂☆recurse{N}()`
# constructor now that it takes two type parameters.
# NOTE(review): passing `E` (rather than `false`) to `∂☆recurse` is an assumption — confirm.
# NOTE(review): `∂r` is not defined in this block as shown — confirm where it comes from.
function (this::∂⃖{N})(::∂☆internal{M, E}, args::AbstractTangentBundle{1}...) where {N, M, E}
    r = this(∂☆shuffle{N}(), args...)
    if primal(r) === nothing
        return this(∂☆recurse{N, E}(), args...)
    else
        z, ∂z = this(v->shuffle_up(v, Val(false)), r) # never taylor_or_bust for mixed mode
        return z, ∂⃖composeOdd{1, c_order(N)}(∂r, ∂z)
    end
end

src/stage1/recurse_fwd.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function fwd_transform(ci::CodeInfo, args...)
7878
return newci
7979
end
8080

81-
function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
81+
function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
8282
new_code = Any[]
8383
@static if VERSION v"1.12.0-DEV.173"
8484
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code))
@@ -112,7 +112,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
112112
args = map(stmt.args) do stmt
113113
emit!(mapstmt!(stmt))
114114
end
115-
return Expr(:call, ∂☆{N}(), args...)
115+
return Expr(:call, ∂☆{N, E}(), args...)
116116
elseif isexpr(stmt, :new)
117117
args = map(stmt.args) do stmt
118118
emit!(mapstmt!(stmt))
@@ -122,7 +122,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
122122
args = map(stmt.args) do stmt
123123
emit!(mapstmt!(stmt))
124124
end
125-
return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
125+
return Expr(:call, Core._apply_iterate, FwdIterate{E}(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
126126
elseif isa(stmt, SSAValue)
127127
return SSAValue(ssa_mapping[stmt.id])
128128
elseif isa(stmt, Core.SlotNumber)

src/stage1/termination.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆internal{N}, Vara
6363
end
6464
end
6565

66+
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆internal{N, E}, Vararg{Any}} where {N, E}, nothing, -1, get_world_counter())
67+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
68+
return true
69+
end
70+
end
71+
6672
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆{N}, Vararg{Any}} where {N}, nothing, -1, get_world_counter())
6773
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
6874
return true

0 commit comments

Comments
 (0)