Skip to content

Commit 52e8cce

Browse files
author
oscarddssmith
committed
separate minor changes from #158
1 parent 1e8eedb commit 52e8cce

File tree

2 files changed

+12
-19
lines changed

2 files changed

+12
-19
lines changed

src/interface.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,14 @@ dx(x::Complex) = error("Tried to take the gradient of a complex-valued function.
5959
dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued function.")
6060

6161
"""
62-
x(x)
62+
xⁿ{N}(x)
6363
64-
For `x` in a one dimensional manifold, map x to the trivial, unital, 1st order
65-
tangent bundle. It should hold that `∀x ⟨∂x(x), dx(x)⟩ = 1`
64+
For `x` in a one dimensional manifold, map x to the trivial, unital, Nth order
65+
tangent bundle. It should hold that `∀x ⟨∂ⁿ{1}x(x), dx(x)⟩ = 1`
6666
"""
67-
∂x(x::Real) = ExplicitTangentBundle{1}(x, (one(x),))
68-
∂x(x) = error("Tangent space not defined for `$(typeof(x)).")
69-
7067
struct ∂xⁿ{N}; end
7168

72-
(::∂xⁿ{N})(x::Real) where {N} = TaylorBundle{N}(x, (one(x), (zero(x) for i = 1:(N-1))...,))
69+
(::∂xⁿ{N})(x::Real) where {N} = TaylorBundle{N}(x, ntuple(i->i==1 ? one(x) : zero(x), N))
7370
(::∂xⁿ)(x) = error("Tangent space not defined for `$(typeof(x)).")
7471

7572
function ChainRules.rrule(∂::∂xⁿ, x)
@@ -173,11 +170,6 @@ raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{N+1,T}(get
173170

174171
(f::PrimeDerivativeFwd{0})(x) = getfield(f, :f)(x)
175172

176-
function (f::PrimeDerivativeFwd{1})(x)
177-
z = ∂☆¹(ZeroBundle{1}(getfield(f, :f)), ∂x(x))
178-
z[TaylorTangentIndex(1)]
179-
end
180-
181173
function (f::PrimeDerivativeFwd{N})(x) where N
182174
z = ∂☆{N}()(ZeroBundle{N}(getfield(f, :f)), ∂xⁿ{N}()(x))
183175
z[TaylorTangentIndex(N)]

src/tangent.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,6 @@ function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:Abstract
334334
StructArray{TaylorBundle{Order, T}}((atb.primal, atb.tangent.coeffs...))
335335
end
336336

337-
function ChainRulesCore.rrule(::typeof(unbundle), atb::TaylorBundle)
338-
unbundle(atb), Δ->throw(Δ)
339-
end
340-
341337
function StructArrays.staticschema(::Type{<:TaylorBundle{N, B, T}}) where {N, B, T}
342338
Tuple{B, T.parameters...}
343339
end
@@ -355,11 +351,11 @@ function StructArrays.createinstance(T::Type{<:TaylorBundle}, args...)
355351
T(first(args), Base.tail(args))
356352
end
357353

358-
function unbundle(zb::ZeroBundle{N, A}) where {N,T,Dim,A<:AbstractArray{T, Dim}}
359-
StructArray{ZeroBundle{N, T}}((zb.primal, fill(zb.tangent.val, size(zb.primal)...)))
354+
function unbundle(u::UniformBundle{N, A}) where {N,T,Dim,A<:AbstractArray{T, Dim}}
355+
StructArray{UniformBundle{N, T}}((u.primal, fill(u.tangent.val, size(u.primal)...)))
360356
end
361357

362-
function ChainRulesCore.rrule(::typeof(unbundle), atb::ZeroBundle)
358+
function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle)
363359
unbundle(atb), Δ->throw(Δ)
364360
end
365361

@@ -383,6 +379,11 @@ function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N}
383379
end)
384380
end
385381

382+
function rebundle(A::AbstractArray{<:UniformBundle{N}}) where {N}
383+
@assert all(x->getfield(x, :tangent)==(first(A).tangent), A)
384+
UniformBundle{N}(map(x->x.primal, A), first(A).tangent.val)
385+
end
386+
386387
function ChainRulesCore.rrule(::typeof(rebundle), atb)
387388
rebundle(atb), Δ->throw(Δ)
388389
end

0 commit comments

Comments
 (0)