Skip to content

Commit 3106100

Browse files
committed
Higher order forward rule for broadcast
1 parent b925ae0 commit 3106100

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

src/stage1/broadcast.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
using Base.Broadcast
2-
using Base.Broadcast: broadcasted
2+
using Base.Broadcast: broadcasted, Broadcasted
3+
4+
# Forward mode broadcast rule
5+
struct FwdBroadcast{N, T<:AbstractTangentBundle{N}}
6+
f::T
7+
end
8+
(f::FwdBroadcast{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
9+
10+
n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))
11+
12+
function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
13+
bc::ATB{N, <:Broadcasted}) where {N}
14+
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
15+
args = n_getfield(∂ₙ, bc, :args)
16+
r = copy(Broadcasted(
17+
FwdMap(n_getfield(∂ₙ, bc, :f)),
18+
ntuple(length(primal(args))) do i
19+
val = n_getfield(∂ₙ, args, i)
20+
if ndims(primal(val)) == 0
21+
return Ref(∂ₙ(ZeroBundle{N}(getindex), val))
22+
else
23+
return unbundle(val)
24+
end
25+
end))
26+
if isa(r, AbstractArray)
27+
r = rebundle(r)
28+
end
29+
return r
30+
end
331

432
# Broadcast over one element is just map
533
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}

src/tangent.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,23 @@ end
223223
$(Expr(:splatnew, B, :x))
224224
end
225225

226+
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
227+
expand_singleton_to_array(asize, a::AbstractArray) = a
228+
226229
function unbundle(atb::TangentBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}}
227-
StructArray{TangentBundle{Order, T}}((atb.primal, atb.partials...))
230+
asize = size(atb.primal)
231+
StructArray{TangentBundle{Order, T}}((atb.primal, map(a->expand_singleton_to_array(asize, a), atb.partials)...))
228232
end
229233

230234
function StructArrays.staticschema(::Type{<:TangentBundle{N, B, T}}) where {N, B, T}
231235
Tuple{B, T.parameters...}
232236
end
233237

238+
function StructArrays.component(m::TangentBundle{N, B, T}, i::Int) where {N, B, T}
239+
i == 1 && return m.primal
240+
return m.partials[i - 1]
241+
end
242+
234243
function StructArrays.createinstance(T::Type{<:TangentBundle}, args...)
235244
T(first(args), Base.tail(args))
236245
end
@@ -251,6 +260,10 @@ function StructArrays.staticschema(::Type{<:TaylorBundle{N, B}}) where {N, B}
251260
Tuple{B, Vararg{Any, N}}
252261
end
253262

263+
function StructArrays.component(m::TaylorBundle{N, B}, i::Int) where {N, B, T}
264+
i == 1 && return m.primal
265+
return m.coeffs[i - 1]
266+
end
254267

255268
function StructArrays.createinstance(T::Type{<:TaylorBundle}, args...)
256269
T(first(args), Base.tail(args))

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,13 @@ end
114114
# Regression tests
115115
@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0]
116116

117+
const fwd = Diffractor.PrimeDerivativeFwd
118+
const bwd = Diffractor.PrimeDerivativeFwd
119+
120+
function f_broadcast(a)
121+
l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]]
122+
return sum(l)
123+
end
124+
@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0)
125+
117126
include("pinn.jl")

0 commit comments

Comments
 (0)