@@ -93,7 +93,37 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
93
93
end
94
94
@ChainRulesCore . non_differentiable shuffle_up (r:: UniformBundle )
95
95
96
+ function shuffle_up_bundle (r:: Diffractor.TaylorBundle{1, <:TaylorBundle{<:Any, Float64}} )
97
+ a = primal (r)
98
+ b = partial (r, 1 )
99
+ z₀ = primal (a)
100
+ z₁ = partial (a, 1 )
101
+ z₂ = b. primal
102
+ z₁₂ = b. tagent
103
+ if z₁ == z₂
104
+ return TaylorBundle {2} (z₀, (z₁, z₁₂))
105
+ else
106
+ return ExplicitTangentBundle {2} (z₀, (z₁, z₂, z₁₂))
107
+ end
108
+ end
109
+
110
+ function shuffle_up_bundle (r:: UniformBundle{1, <:UniformBundle{N, B, U}} ) where {N, B, U}
111
+ return UniformBundle {N+1, B, U} (primal (primal (r)))
112
+ end
113
+
114
+ function shuffle_down_bundle (b:: ExplicitTangentBundle{N, B} ) where {N, B}
115
+ error (" TODO" )
116
+ end
117
+
118
+ function shuffle_down_bundle (b:: TaylorBundle{2, B} ) where {B}
119
+ z₀ = primal (b)
120
+ z₁ = b. tangent. coeffs[1 ]
121
+ z₁₂ = b. tangent. coeffs[2 ]
122
+ TaylorBundle {1} (TaylorBundle {1} (z₀, (z₁,)), (TaylorBundle {1} (z₁, (z₁₂,)),))
123
+ end
124
+
96
125
struct ∂☆internal{N}; end
126
+ struct ∂☆recurse{N}; end
97
127
struct ∂☆shuffle{N}; end
98
128
99
129
function shuffle_base (r)
@@ -158,8 +188,13 @@ function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
158
188
return shuffle_up (r)
159
189
end
160
190
end
161
- (:: ∂☆{N})(args:: AbstractTangentBundle{N} ...) where {N} = ∂☆internal {N} ()(args... )
162
191
192
+ # TODO : Generalize to N,M
193
+ @inline function (:: ∂☆{1 })(rec:: AbstractZeroBundle{1, ∂☆recurse{1}} , args:: ATB{1} ...)
194
+ return shuffle_down_bundle (∂☆recurse {2} ()(map (shuffle_up_bundle, args)... ))
195
+ end
196
+
197
+ (:: ∂☆{N})(args:: AbstractTangentBundle{N} ...) where {N} = ∂☆internal {N} ()(args... )
163
198
164
199
# Special case rules for performance
165
200
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: TangentBundle{N} , s:: AbstractTangentBundle{N} ) where {N}
0 commit comments