Skip to content

Commit c49b880

Browse files
committed
=Add more shuffle_up_bundles
squash me
1 parent 2c88339 commit c49b880

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/stage1/forward.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,23 +95,43 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
9595
end
9696
@ChainRulesCore.non_differentiable shuffle_up(r::UniformBundle)
9797

98-
function shuffle_up_bundle(r::Diffractor.TaylorBundle{1, <:TaylorBundle{<:Any, Float64}})
98+
function shuffle_up_bundle(r::Diffractor.TangentBundle{1, <:TaylorBundle{1}})
9999
a = primal(r)
100100
b = partial(r, 1)
101101
z₀ = primal(a)
102102
z₁ = partial(a, 1)
103103
z₂ = b.primal
104-
z₁₂ = b.tagent
104+
z₁₂ = only(b.tangent.coeffs)
105+
106+
if z₁ == z₂
107+
return TaylorBundle{2}(z₀, (z₁, z₁₂))
108+
else
109+
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
110+
end
111+
end
112+
113+
function shuffle_up_bundle(r::Diffractor.TangentBundle{1, <:ExplicitTangentBundle{1}})
114+
a = primal(r)
115+
b = partial(r, 1)
116+
z₀ = primal(a)
117+
z₁ = partial(a, 1)
118+
z₂ = b.primal
119+
z₂ = b.primal
120+
z₁₂ = only(b.tangent.coeffs)
105121
if z₁ == z₂
106122
return TaylorBundle{2}(z₀, (z₁, z₁₂))
107123
else
108124
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
109125
end
110126
end
111127

128+
112129
function shuffle_up_bundle(r::UniformBundle{1, <:UniformBundle{N, B, U}}) where {N, B, U}
113130
return UniformBundle{N+1, B, U}(primal(primal(r)))
114131
end
132+
function shuffle_up_bundle(r::UniformBundle{1, <:UniformBundle{1, B, U}}) where {B, U} # break ambig
133+
return UniformBundle{2, B, U}(primal(primal(r)))
134+
end
115135

116136
function shuffle_down_bundle(b::ExplicitTangentBundle{N, B}) where {N, B}
117137
error("TODO")
@@ -166,6 +186,7 @@ end
166186
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
167187
∂☆p = ∂☆{N-1}()
168188
downargs = map(shuffle_down, args)
189+
#@info "∂☆shuffle{N}" args downargs
169190
tupargs = ∂vararg{N-1}()(map(first_partial, downargs)...)
170191
∂☆p(ZeroBundle{N-1}(frule), #= ZeroBundle{N-1}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
171192
end

0 commit comments

Comments
 (0)