Skip to content

Commit 4fba774

Browse files
committed
Use higher order to evaluate repeated first-order
remove compositebundle shash
1 parent 80d89a6 commit 4fba774

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/stage1/forward.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,37 @@ function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
9393
end
9494
@ChainRulesCore.non_differentiable shuffle_up(r::UniformBundle)
9595

96+
function shuffle_up_bundle(r::Diffractor.TaylorBundle{1, <:TaylorBundle{<:Any, Float64}})
97+
a = primal(r)
98+
b = partial(r, 1)
99+
z₀ = primal(a)
100+
z₁ = partial(a, 1)
101+
z₂ = b.primal
102+
z₁₂ = b.tagent
103+
if z₁ == z₂
104+
return TaylorBundle{2}(z₀, (z₁, z₁₂))
105+
else
106+
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
107+
end
108+
end
109+
110+
function shuffle_up_bundle(r::UniformBundle{1, <:UniformBundle{N, B, U}}) where {N, B, U}
111+
return UniformBundle{N+1, B, U}(primal(primal(r)))
112+
end
113+
114+
function shuffle_down_bundle(b::ExplicitTangentBundle{N, B}) where {N, B}
115+
error("TODO")
116+
end
117+
118+
function shuffle_down_bundle(b::TaylorBundle{2, B}) where {B}
119+
z₀ = primal(b)
120+
z₁ = b.tangent.coeffs[1]
121+
z₁₂ = b.tangent.coeffs[2]
122+
TaylorBundle{1}(TaylorBundle{1}(z₀, (z₁,)), (TaylorBundle{1}(z₁, (z₁₂,)),))
123+
end
124+
96125
struct ∂☆internal{N}; end
126+
struct ∂☆recurse{N}; end
97127
struct ∂☆shuffle{N}; end
98128

99129
function shuffle_base(r)
@@ -158,8 +188,13 @@ function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
158188
return shuffle_up(r)
159189
end
160190
end
161-
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
162191

192+
# TODO: Generalize to N,M
193+
@inline function (::∂☆{1})(rec::AbstractZeroBundle{1, ∂☆recurse{1}}, args::ATB{1}...)
194+
return shuffle_down_bundle(∂☆recurse{2}()(map(shuffle_up_bundle, args)...))
195+
end
196+
197+
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
163198

164199
# Special case rules for performance
165200
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}

src/stage1/recurse_fwd.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
struct ∂☆recurse{N}; end
21

32
struct ∂vararg{N}; end
43

0 commit comments

Comments
 (0)