@@ -95,13 +95,14 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
95
95
end
96
96
@ChainRulesCore . non_differentiable shuffle_up (r:: UniformBundle )
97
97
98
- function shuffle_up_bundle (r:: Diffractor.TangentBundle{1, <:TaylorBundle{1}} )
98
+
99
+ function shuffle_up_bundle (r:: Diffractor.TangentBundle{1, B} ) where {B<: ATB{1} }
99
100
a = primal (r)
100
101
b = partial (r, 1 )
101
102
z₀ = primal (a)
102
103
z₁ = partial (a, 1 )
103
104
z₂ = b. primal
104
- z₁₂ = only ( b. tangent. coeffs )
105
+ z₁₂ = _shuffle_up_partial₁₂ (B, b. tangent)
105
106
106
107
if z₁ == z₂
107
108
return TaylorBundle {2} (z₀, (z₁, z₁₂))
@@ -110,20 +111,9 @@ function shuffle_up_bundle(r::Diffractor.TangentBundle{1, <:TaylorBundle{1}})
110
111
end
111
112
end
112
113
113
- function shuffle_up_bundle (r:: Diffractor.TangentBundle{1, <:ExplicitTangentBundle{1}} )
114
- a = primal (r)
115
- b = partial (r, 1 )
116
- z₀ = primal (a)
117
- z₁ = partial (a, 1 )
118
- z₂ = b. primal
119
- z₂ = b. primal
120
- z₁₂ = only (b. tangent. coeffs)
121
- if z₁ == z₂
122
- return TaylorBundle {2} (z₀, (z₁, z₁₂))
123
- else
124
- return ExplicitTangentBundle {2} (z₀, (z₁, z₂, z₁₂))
125
- end
126
- end
114
+ _shuffle_up_partial₁₂ (:: Type{<:TaylorBundle} , tangent) = only (tangent. coeffs)
115
+ _shuffle_up_partial₁₂ (:: Type{<:ExplicitTangentBundle} , tangent) = only (tangent. partials)
116
+ _shuffle_up_partial₁₂ (:: Type{<:UniformBundle} , tangent) = tangent. val
127
117
128
118
129
119
function shuffle_up_bundle (r:: UniformBundle{1, <:UniformBundle{N, B, U}} ) where {N, B, U}
0 commit comments