Skip to content

Commit 6c877da

Browse files
committed
and shuffle_up_bundle for UniformBundle and refactor
1 parent a643d7f commit 6c877da

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

src/stage1/forward.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,14 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
9595
end
9696
@ChainRulesCore.non_differentiable shuffle_up(r::UniformBundle)
9797

98-
function shuffle_up_bundle(r::Diffractor.TangentBundle{1, <:TaylorBundle{1}})
98+
99+
function shuffle_up_bundle(r::Diffractor.TangentBundle{1, B}) where {B<:ATB{1}}
99100
a = primal(r)
100101
b = partial(r, 1)
101102
z₀ = primal(a)
102103
z₁ = partial(a, 1)
103104
z₂ = b.primal
104-
z₁₂ = only(b.tangent.coeffs)
105+
z₁₂ = _shuffle_up_partial₁₂(B, b.tangent)
105106

106107
if z₁ == z₂
107108
return TaylorBundle{2}(z₀, (z₁, z₁₂))
@@ -110,20 +111,9 @@ function shuffle_up_bundle(r::Diffractor.TangentBundle{1, <:TaylorBundle{1}})
110111
end
111112
end
112113

113-
function shuffle_up_bundle(r::Diffractor.TangentBundle{1, <:ExplicitTangentBundle{1}})
114-
a = primal(r)
115-
b = partial(r, 1)
116-
z₀ = primal(a)
117-
z₁ = partial(a, 1)
118-
z₂ = b.primal
119-
z₂ = b.primal
120-
z₁₂ = only(b.tangent.coeffs)
121-
if z₁ == z₂
122-
return TaylorBundle{2}(z₀, (z₁, z₁₂))
123-
else
124-
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
125-
end
126-
end
114+
_shuffle_up_partial₁₂(::Type{<:TaylorBundle}, tangent) = only(tangent.coeffs)
115+
_shuffle_up_partial₁₂(::Type{<:ExplicitTangentBundle}, tangent) = only(tangent.partials)
116+
_shuffle_up_partial₁₂(::Type{<:UniformBundle}, tangent) = tangent.val
127117

128118

129119
function shuffle_up_bundle(r::UniformBundle{1, <:UniformBundle{N, B, U}}) where {N, B, U}

0 commit comments

Comments
 (0)