@@ -33,38 +33,68 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
33
33
ntuple (_sdown, N- 1 ))
34
34
end
35
35
36
- @noinline check_taylor (z₁, z₂) = @assert (z₁ == z₂, " $z₁ == $z₂ " )
36
"""
    TaylorRequired(order, z₁, z₂)

Error thrown by `shuffle_up` when its `Val{true}` ("taylor or bust") flag is set and
the bundle being shuffled is not Taylor compatible.

# Fields
- `order`: the order reported by `find_taylor_incompatibility` at which the check failed.
- `z₁`, `z₂`: the two values that were required to compare equal but did not.
"""
struct TaylorRequired <: Exception
    order
    z₁
    z₂
end
41
# Pretty-print a `TaylorRequired` error.
#
# The method is restricted to `TaylorRequired`; an untyped `err` argument would
# replace the display of every exception in the session. Each part of the message is
# written to `io` on its own line.
function Base.showerror(io::IO, err::TaylorRequired)
    order_str1 = order_str(err.order)
    println(io, "In Eras mode all higher order derivatives must be taylor, but encountered one where the taylor requirement z₁ == z₂ was not met.")
    println(io, "derivative on $order_str1 path: z₁ = ", err.z₁)
    println(io, "$order_str1 on the derivative path: z₂ = ", err.z₂)
    return nothing
end
47
+
48
# Human-readable name for a derivative order: 0 is the primal, 1 the derivative,
# and higher orders are spelled as ordinals ("2nd derivative", "3rd derivative",
# "4th derivative", ...). `order` must be non-negative.
function order_str(order::Integer)
    @assert order >= 0
    order == 0 && return "primal"
    order == 1 && return "derivative"
    order == 2 && return "2nd derivative"
    order == 3 && return "3rd derivative"
    return "$(order)th derivative"
end
37
62
38
- function shuffle_up (r:: TaylorBundle{1, Tuple{B1,B2}} ) where {B1,B2}
63
"Find the lowest-order derivative that is not Taylor compatible; return -1 if every order is compatible."
64
@noinline function find_taylor_incompatibility(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1, B2}
    # Order 0: the first partial's first component must equal the primal's second.
    if partial(r, 1)[1] != primal(r)[2]
        return 0
    end
    # Order k: the (k+1)-th partial's first component must equal the k-th partial's second.
    for k in 1:(N - 1)
        if partial(r, k + 1)[1] != partial(r, k)[2]
            return k
        end
    end
    # Sentinel: every check passed.
    return -1
end
71
+
72
# Return the pair `(z₁, z₂)` of values whose mismatch made `r` Taylor incompatible at
# `fail_order`, where `fail_order` is the non-negative order reported by
# `find_taylor_incompatibility`.
#
# The signature mirrors `find_taylor_incompatibility` so that bundles with concrete
# tuple element types dispatch here (`Tuple{Any,Any}` as an invariant parameter
# would only match that exact type).
function taylor_failure_values(r::TaylorBundle{N, Tuple{B1,B2}}, fail_order) where {N, B1, B2}
    # Order 0 compares against the primal rather than a lower partial.
    fail_order == 0 && return partial(r, 1)[1], primal(r)[2]
    return partial(r, fail_order + 1)[1], partial(r, fail_order)[2]
end
76
+
77
# Shuffle a first-order bundle over a pair up into a second-order bundle.
#
# If the bundle is Taylor compatible the result is a `TaylorBundle{2}`. Otherwise,
# when `taylor_or_bust` is `true` a `TaylorRequired` error is thrown, and when it is
# `false` the result falls back to an `ExplicitTangentBundle{2}`.
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}, ::Val{taylor_or_bust}) where {B1, B2, taylor_or_bust}
    z₀ = primal(r)[1]
    z₁ = partial(r, 1)[1]
    z₂ = primal(r)[2]
    z₁₂ = partial(r, 1)[2]

    fail_order = find_taylor_incompatibility(r)
    fail_order < 0 && return TaylorBundle{2}(z₀, (z₁, z₁₂))

    if taylor_or_bust
        @assert fail_order == 0  # can't be higher
        throw(TaylorRequired(fail_order, z₁, z₂))
    end
    return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
end
50
93
51
- function taylor_compatible (a:: ATB{N} , b:: ATB{N} ) where {N}
52
- primal (b) === a[TaylorTangentIndex (1 )] || return false
53
- return all (1 : (N- 1 )) do i
54
- b[TaylorTangentIndex (i)] === a[TaylorTangentIndex (i+ 1 )]
55
- end
56
- end
57
-
58
- function taylor_compatible (r:: TaylorBundle{N, Tuple{B1,B2}} ) where {N, B1,B2}
59
- partial (r, 1 )[1 ] == primal (r)[2 ] || return false
60
- return all (1 : N- 1 ) do i
61
- partial (r, i+ 1 )[1 ] == partial (r, i)[2 ]
62
- end
63
- end
64
- function shuffle_up (r:: TaylorBundle{N, Tuple{B1,B2}} ) where {N, B1,B2}
94
+ function shuffle_up (r:: TaylorBundle{N, Tuple{B1,B2}} , :: Val{taylor_or_bust} ) where {N, B1,B2, taylor_or_bust}
65
95
the_primal = primal (r)[1 ]
66
- if true
67
- @assert taylor_compatible (r)
96
+ taylor_fail_order = find_taylor_incompatibility (r)
97
+ if taylor_fail_order (r) < 0
68
98
the_partials = ntuple (N+ 1 ) do i
69
99
if i <= N
70
100
partial (r, i)[1 ] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
@@ -73,6 +103,9 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
73
103
end
74
104
end
75
105
return TaylorBundle {N+1} (the_primal, the_partials)
106
+ elseif taylor_or_bust
107
+ @assert taylor_fail_order < N
108
+ throw (TaylorRequired (taylor_fail_order, taylor_failure_values (r, taylor_fail_order)... ))
76
109
else
77
110
# XXX : am dubious of the correctness of this
78
111
a_partials = ntuple (i-> partial (r, i)[1 ], N)
@@ -83,7 +116,7 @@ function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
83
116
end
84
117
85
118
86
- function shuffle_up (r:: UniformBundle{N, B, U} ) where {N, B, U}
119
+ function shuffle_up (r:: UniformBundle{N, B, U} , _ :: Val ) where {N, B, U}
87
120
(a, b) = primal (r)
88
121
if r. tangent. val === b
89
122
u = b
@@ -94,7 +127,7 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
94
127
end
95
128
UniformBundle {N+1} (a, u)
96
129
end
97
- @ChainRulesCore . non_differentiable shuffle_up (r:: UniformBundle )
130
+ @ChainRulesCore . non_differentiable shuffle_up (r:: UniformBundle , :: Val )
98
131
99
132
100
133
function shuffle_up_bundle (r:: Diffractor.TangentBundle{1, B} ) where {B<: ATB{1} }
@@ -124,9 +157,6 @@ function shuffle_up_bundle(r::UniformBundle{1, <:UniformBundle{1, B, U}}) where
124
157
return UniformBundle {2, B, U} (primal (primal (r)))
125
158
end
126
159
127
- function shuffle_down_bundle (b:: ExplicitTangentBundle{N, B} ) where {N, B}
128
- error (" TODO" )
129
- end
130
160
131
161
function shuffle_down_bundle (b:: TaylorBundle{2, B} ) where {B}
132
162
z₀ = primal (b)
@@ -135,8 +165,10 @@ function shuffle_down_bundle(b::TaylorBundle{2, B}) where {B}
135
165
TaylorBundle {1} (TaylorBundle {1} (z₀, (z₁,)), (TaylorBundle {1} (z₁, (z₁₂,)),))
136
166
end
137
167
138
- struct ∂☆internal{N}; end
139
- struct ∂☆recurse{N}; end
168
# Type parameters shared by the callable singletons below:
#   N - the order; this should be a positive Int.
#   E - Eras mode; a Bool controlling whether we should error if a result isn't Taylor.
struct ∂☆internal{N, E} end
struct ∂☆recurse{N, E} end
struct ∂☆shuffle{N} end
141
173
142
174
function shuffle_base (r)
@@ -151,26 +183,28 @@ function shuffle_base(r)
151
183
end
152
184
end
153
185
154
- function (:: ∂☆internal{1 })(args:: AbstractTangentBundle{1} ...)
155
- r = _frule (map (first_partial, args), map (primal, args)... )
186
# First-order entry point: try a forward rule via `_frule`; if none applies
# (`nothing` is returned), fall back to `∂☆recurse{1, E}`.
function (::∂☆internal{1, E})(args::AbstractTangentBundle{1}...) where {E}
    tangents = map(first_partial, args)
    primals = map(primal, args)
    r = _frule(Val{E}(), tangents, primals...)
    return r === nothing ? ∂☆recurse{1, E}()(args...) : shuffle_base(r)
end
162
194
163
- _frule (partials, primals... ) = frule (#= = DiffractorRuleConfig(), ==# partials, primals... )
164
- function _frule (:: NTuple{<:Any, AbstractZero} , f, primal_args... )
195
# TODO: work out why enabling calling back into AD in Eras mode causes type instability
# `Val{true}` (Eras mode): plain `frule`, without a rule config.
function _frule(::Val{true}, partials, primals...)
    return frule(partials, primals...)
end

# `Val{false}`: pass `DiffractorRuleConfig()` to `frule`.
function _frule(::Val{false}, partials, primals...)
    return frule(DiffractorRuleConfig(), partials, primals...)
end

# frules are linear in partials, so zero maps to zero and there is no need to evaluate
# the frule. If all partials are immutable AbstractZero subtypes we know we don't have
# to worry about a mutating frule either.
function _frule(::Any, ::NTuple{<:Any, AbstractZero}, f, primal_args...)
    primal_result = f(primal_args...)
    return primal_result, zero_tangent(primal_result)
end
170
204
171
205
# Let ChainRules rules call back into Diffractor's forward mode: bundle each partial
# with its argument, run the first-order transform with Eras mode off, and return the
# primal together with the first partial of the result.
# NOTE(review): `bundle` is called as `bundle(partial, arg)`; confirm that argument
# order against its definition.
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
    bundled_args = map(bundle, partials, args)
    out = ∂☆internal{1, false}()(bundled_args...)
    return primal(out), first_partial(out)
end
176
210
@@ -194,21 +228,21 @@ function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundl
194
228
return zero_bundle {1} ()(f_v (args_v... ))
195
229
end
196
230
197
- function (:: ∂☆internal{N})(args:: AbstractTangentBundle{N} ...) where {N}
231
# Order-N entry point: apply `∂☆shuffle{N}`; when its result has a `nothing` primal,
# recurse, otherwise shuffle the result up, passing the Eras flag `E` along.
function (::∂☆internal{N, E})(args::AbstractTangentBundle{N}...) where {N, E}
    shuffled = ∂☆shuffle{N}()(args...)
    primal(shuffled) === nothing && return ∂☆recurse{N, E}()(args...)
    return shuffle_up(shuffled, Val{E}())
end
205
239
206
240
# TODO: Generalize to N,M
# Differentiating `∂☆recurse{1, E}` itself at first order: lift every argument with
# `shuffle_up_bundle`, run the second-order recursion, then shuffle the result down.
@inline function (::∂☆{1, E})(rec::AbstractZeroBundle{1, ∂☆recurse{1, E}}, args::ATB{1}...) where {E}
    lifted = map(shuffle_up_bundle, args)
    second_order = ∂☆recurse{2, E}()(lifted...)
    return shuffle_down_bundle(second_order)
end
210
244
211
- (:: ∂☆{N})(args:: AbstractTangentBundle{N} ...) where {N} = ∂☆internal {N} ()(args... )
245
# Generic fallback: forward to `∂☆internal` with the same order `N` and Eras flag `E`.
# Both `N` and `E` must be bound in the `where` clause.
(::∂☆{N, E})(args::AbstractTangentBundle{N}...) where {N, E} = ∂☆internal{N, E}()(args...)
212
246
213
247
# Special case rules for performance
214
248
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: TangentBundle{N} , s:: AbstractTangentBundle{N} ) where {N}
@@ -252,22 +286,23 @@ function (::∂☆{N})(f::ATB{N, typeof(tuple)}, args::AbstractZeroBundle{N}...)
252
286
ZeroBundle {N} (map (primal, args)) # special fast case
253
287
end
254
288
255
- struct FwdMap{N, T<: AbstractTangentBundle{N} }
289
# Callable wrapper that applies the order-`N` forward transform of the bundled
# function `f` (with Eras flag `E`) to its arguments.
struct FwdMap{N, E, T<:AbstractTangentBundle{N}}
    f::T
end

# Convenience constructor: give only `E`; `N` and `T` are taken from the bundle.
function FwdMap{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}}
    return FwdMap{N, E, T}(f)
end

function (f::FwdMap{N, E})(args::AbstractTangentBundle{N}...) where {N, E}
    return ∂☆{N, E}()(f.f, args...)
end
259
294
260
- function (:: ∂☆{N})(:: AbstractZeroBundle{N, typeof(map)} , f:: ATB{N} , tup:: TaylorBundle{N, <:Tuple} ) where {N}
261
- ∂vararg {N} ()(map (FwdMap (f), destructure (tup))... )
295
# `map` over a Taylor bundle of a tuple: destructure the tuple bundle, map the
# forward-transformed `f` over the pieces, and recombine them with `∂vararg`.
function (::∂☆{N, E})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N, E}
    mapped = map(FwdMap{E}(f), destructure(tup))
    return ∂vararg{N}()(mapped...)
end
263
298
264
- function (:: ∂☆{N})(:: AbstractZeroBundle{N, typeof(map)} , f:: ATB{N} , args:: ATB{N, <:AbstractArray} ...) where {N}
299
# `map` over bundles of arrays: unbundle each argument, map the forward-transformed
# `f` over them, and rebundle the result.
function (::∂☆{N, E})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N, E}
    # TODO: This could do an inplace map! to avoid the extra rebundling
    unbundled = map(unbundle, args)
    return rebundle(map(FwdMap{E}(f), unbundled...))
end
268
303
269
- function (:: ∂☆{N})(:: AbstractZeroBundle{N, typeof(map)} , f:: ATB{N} , args:: ATB{N} ...) where {N}
270
- ∂☆recurse {N} ()(ZeroBundle {N, typeof(map)} (map), f, args... )
304
# Fallback for `map` over any other bundles: recurse into `map` itself.
function (::∂☆{N, E})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N, E}
    map_bundle = ZeroBundle{N, typeof(map)}(map)
    return ∂☆recurse{N, E}()(map_bundle, f, args...)
end
272
307
273
308
@@ -279,29 +314,29 @@ function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N
279
314
Core. ifelse (arg. primal, args... )
280
315
end
281
316
282
- struct FwdIterate{N, T<: AbstractTangentBundle{N} }
317
# Callable wrapper around a bundled iteration function `f`, carrying the order `N`
# and Eras flag `E` in its type.
struct FwdIterate{N, E, T<:AbstractTangentBundle{N}}
    f::T
end

# Convenience constructor: give only `E`; `N` and `T` are taken from the bundle.
function FwdIterate{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}}
    return FwdIterate{N, E, T}(f)
end
321
# First iteration step: forward-transform `f.f(arg)`. Returns `nothing` when the
# transformed result bundles `nothing`; otherwise returns the bundled element and the
# primal of the bundled state.
function (f::FwdIterate{N, E})(arg::ATB{N}) where {N, E}
    r = ∂☆{N, E}()(f.f, arg)
    # `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
    isa(r, ATB{N, Nothing}) && return nothing
    element = ∂☆{N, E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1))
    state = ∂☆{N, E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))
    return (element, primal(state))
end
292
- @Base . constprop :aggressive function (f:: FwdIterate )(arg:: ATB{N} , st) where {N}
293
- r = ∂☆ {N} ()(f. f, arg, ZeroBundle {N} (st))
328
# Subsequent iteration step: forward-transform `f.f(arg, st)`, with the state `st`
# wrapped in a `ZeroBundle`. Returns `nothing` when the transformed result bundles
# `nothing`; otherwise returns the bundled element and the primal of the bundled state.
#
# Every inner transform uses `∂☆{N, E}` so the Eras flag `E` is propagated to both
# `getindex` calls, matching the one-argument method.
@Base.constprop :aggressive function (f::FwdIterate{N, E})(arg::ATB{N}, st) where {N, E}
    r = ∂☆{N, E}()(f.f, arg, ZeroBundle{N}(st))
    # `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
    isa(r, ATB{N, Nothing}) && return nothing
    return (∂☆{N, E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
            primal(∂☆{N, E}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
end
299
335
300
- function (this: :∂☆ {N})(:: AbstractZeroBundle{N, typeof(Core._apply_iterate)} , iterate:: ATB{N} , f:: ATB{N} , args:: ATB{N} ...) where {N}
301
- Core. _apply_iterate (FwdIterate (iterate), this, (f,), args... )
336
# Forward transform of `Core._apply_iterate`: wrap the bundled `iterate` function in
# a `FwdIterate` carrying the Eras flag, and apply this transform itself to the
# bundled `f` followed by the iterated arguments.
function (this::∂☆{N, E})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N, E}
    fwd_iterate = FwdIterate{E}(iterate)
    return Core._apply_iterate(fwd_iterate, this, (f,), args...)
end
303
339
304
-
305
340
function (this: :∂☆ {N})(:: AbstractZeroBundle{N, typeof(iterate)} , t:: TaylorBundle{N, <:Tuple} ) where {N}
306
341
r = iterate (destructure (t))
307
342
r === nothing && return ZeroBundle {N} (nothing )
0 commit comments