@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
44partial (x:: UniformTangent , i) = getfield (x, :val )
55partial (x:: ProductTangent , i) = ProductTangent (map (x-> partial (x, i), getfield (x, :factors )))
66partial (x:: AbstractZero , i) = x
7- partial (x:: CompositeBundle{N, B} , i) where {N, B<: Tuple } = Tangent {B} (map (x-> partial (x, i), getfield (x, :tup ))... )
8- function partial (x:: CompositeBundle{N, B} , i) where {N, B}
9- # This is tangent for a struct, but fields partials are each stored in a plain tuple
10- # so we add the names back using the primal `B`
11- # TODO : If required this can be done as a `@generated` function so it is type-stable
12- backing = NamedTuple {fieldnames(B)} (map (x-> partial (x, i), getfield (x, :tup )))
13- return Tangent {B, typeof(backing)} (backing)
14- end
157
168
179primal (x:: AbstractTangentBundle ) = x. primal
@@ -42,20 +34,12 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
4234 ntuple (_sdown, N- 1 ))
4335end
4436
45- function shuffle_down (b:: CompositeBundle{N, B} ) where {N, B}
46- z = CompositeBundle {N-1, CompositeBundle{1, B}} (
47- (CompositeBundle {N-1, Tuple} (
48- map (shuffle_down, b. tup)
49- ),)
50- )
51- z
52- end
5337
54- function shuffle_up (r:: CompositeBundle{1} )
55- z₀ = primal (r. tup [1 ])
56- z₁ = partial (r. tup[ 1 ] , 1 )
57- z₂ = primal (r. tup [2 ])
58- z₁₂ = partial (r. tup[ 2 ] , 1 )
38+ function shuffle_up (r:: TaylorBundle{1, Tuple{B1,B2}} ) where {B1,B2}
39+ z₀ = primal (r) [1 ]
40+ z₁ = partial (r, 1 )[ 1 ]
41+ z₂ = primal (r) [2 ]
42+ z₁₂ = partial (r, 1 )[ 2 ]
5943 if z₁ == z₂
6044 return TaylorBundle {2} (z₀, (z₁, z₁₂))
6145 else
@@ -70,26 +54,33 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
7054 end
7155end
7256
73- # Check whether the tangent bundle element is taylor-like
74- isswifty (:: TaylorBundle ) = true
75- isswifty (:: UniformBundle ) = true
76- isswifty (b:: CompositeBundle ) = all (isswifty, b. tup)
77- isswifty (:: Any ) = false
78-
79- function shuffle_up (r:: CompositeBundle{N} ) where {N}
80- a, b = r. tup
81- if isswifty (a) && isswifty (b) && taylor_compatible (a, b)
82- return TaylorBundle {N+1} (primal (a),
83- ntuple (i-> i == N+ 1 ?
84- b[TaylorTangentIndex (i- 1 )] : a[TaylorTangentIndex (i)],
85- N+ 1 ))
57+ function taylor_compatible (r:: TaylorBundle{N, Tuple{B1,B2}} ) where {N, B1,B2}
58+ partial (r, 1 )[1 ] == primal (r)[2 ] || return false
59+ return all (1 : N- 1 ) do i
60+ partial (r, i+ 1 )[1 ] == partial (r, i)[2 ]
61+ end
62+ end
63+ function shuffle_up (r:: TaylorBundle{N, Tuple{B1,B2}} ) where {N, B1,B2}
64+ the_primal = primal (r)[1 ]
65+ if taylor_compatible (r)
66+ the_partials = ntuple (N+ 1 ) do i
67+ if i <= N
68+ partial (r, i)[1 ] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
69+ else # ii = N+1
70+ partial (r, i- 1 )[2 ]
71+ end
72+ end
73+ return TaylorBundle {N+1} (the_primal, the_partials)
8674 else
87- return TangentBundle {N+1} (r. tup[1 ]. primal,
88- (r. tup[1 ]. tangent. partials... , primal (b),
89- ntuple (i-> partial (b,i), 1 << (N+ 1 )- 1 )... ))
75+ # XXX : am dubious of the correctness of this
76+ a_partials = ntuple (i-> partial (r, i)[1 ], N)
77+ b_partials = ntuple (i-> partial (r, i)[2 ], N)
78+ the_partials = (a_partials... , primal_b, b_partials... )
79+ return TangentBundle {N+1} (the_primal, the_partials)
9080 end
9181end
9282
83+
9384function shuffle_up (r:: UniformBundle{N, B, U} ) where {N, B, U}
9485 (a, b) = primal (r)
9586 if r. tangent. val === b
185176 map (y-> lifted_getfield (y, s), x. tangent. coeffs))
186177end
187178
188- @Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N} , s:: AbstractTangentBundle{N, Int} ) where {N}
189- x. tup[primal (s)]
190- end
191-
192- @Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N, B} , s:: AbstractTangentBundle{N, Symbol} ) where {N, B}
193- x. tup[Base. fieldindex (B, primal (s))]
194- end
195179
196180@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: UniformBundle{N, <:Any, U} , s:: AbstractTangentBundle{N} ) where {N, U}
197181 UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s)), x. tangent. val)
@@ -210,8 +194,8 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
210194end
211195(f:: FwdMap{N} )(args:: AbstractTangentBundle{N} ...) where {N} = ∂☆ {N} ()(f. f, args... )
212196
213- function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , tup:: CompositeBundle {N, <:Tuple} ) where {N}
214- ∂vararg {N} ()(map (FwdMap (f), tup. tup )... )
197+ function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , tup:: TaylorBundle {N, <:Tuple} ) where {N}
198+ ∂vararg {N} ()(map (FwdMap (f), destructure ( tup) )... )
215199end
216200
217201function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , args:: ATB{N, <:AbstractArray} ...) where {N}
@@ -254,35 +238,37 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
254238 Core. _apply_iterate (FwdIterate (iterate), this, (f,), args... )
255239end
256240
257- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: CompositeBundle{N, <:Tuple} ) where {N}
258- r = iterate (t. tup)
241+
242+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: TaylorBundle{N, <:Tuple} ) where {N}
243+ r = iterate (destructure (t))
259244 r === nothing && return ZeroBundle {N} (nothing )
260245 ∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
261246end
262247
263- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: CompositeBundle {N, <:Tuple} , a:: ATB{N} , args:: ATB{N} ...) where {N}
264- r = iterate (t . tup , primal (a), map (primal, args)... )
248+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: TaylorBundle {N, <:Tuple} , a:: ATB{N} , args:: ATB{N} ...) where {N}
249+ r = iterate (destructure (t) , primal (a), map (primal, args)... )
265250 r === nothing && return ZeroBundle {N} (nothing )
266251 ∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
267252end
268253
269- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: CompositeBundle {N, <:Tuple} , i:: ATB{N} ) where {N}
270- r = Base. indexed_iterate (t . tup , primal (i))
254+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: TaylorBundle {N, <:Tuple} , i:: ATB{N} ) where {N}
255+ r = Base. indexed_iterate (destructure (t) , primal (i))
271256 ∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
272257end
273258
274- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: CompositeBundle {N, <:Tuple} , i:: ATB{N} , st1:: ATB{N} , st:: ATB{N} ...) where {N}
275- r = Base. indexed_iterate (t . tup , primal (i), primal (st1), map (primal, st)... )
259+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: TaylorBundle {N, <:Tuple} , i:: ATB{N} , st1:: ATB{N} , st:: ATB{N} ...) where {N}
260+ r = Base. indexed_iterate (destructure (t) , primal (i), primal (st1), map (primal, st)... )
276261 ∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
277262end
278263
279264function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: TangentBundle{N, <:Tuple} , i:: ATB{N} , st:: ATB{N} ...) where {N}
280265 ∂vararg {N} ()(this (ZeroBundle {N} (getfield), t, i), ZeroBundle {N} (primal (i) + 1 ))
281266end
282267
283-
284- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(getindex)} , t:: CompositeBundle{N, <:Tuple} , i:: ZeroBundle ) where {N}
285- t. tup[primal (i)]
268+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(getindex)} , t:: TaylorBundle{N, <:Tuple} , i:: ZeroBundle ) where {N}
269+ field_ind = primal (i)
270+ the_partials = ntuple (order_ind-> partial (t, order_ind)[field_ind], N)
271+ TaylorBundle {N} (primal (t)[field_ind], the_partials)
286272end
287273
288274function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(typeof)} , x:: ATB{N} ) where {N}
0 commit comments