Skip to content

Commit a83c41d

Browse files
committed
Add missing file
1 parent 193a683 commit a83c41d

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

src/stage1/broadcast.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using Base.Broadcast
2+
using Base.Broadcast: broadcasted
3+
4+
# Broadcast over one element is just map
5+
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
6+
∂⃖ₙ(map, f, a)
7+
end
8+
9+
# The below is from Zygote: TODO: DO we want to do something better here?
10+
11+
accum_sum(xs::Nothing; dims = :) = NoTangent()
12+
accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
13+
accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
14+
accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
15+
accum_sum(xs::Number; dims = :) = xs
16+
17+
# https://github.com/FluxML/Zygote.jl/issues/594
18+
function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
19+
Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
20+
end
21+
22+
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
23+
24+
unbroadcast(x::AbstractArray, x̄) =
25+
size(x) == size(x̄) ?:
26+
length(x) == length(x̄) ? trim(x, x̄) :
27+
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
28+
29+
unbroadcast(x::Number, x̄) = accum_sum(x̄)
30+
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
31+
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
32+
33+
unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
34+
35+
const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
36+
37+
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
38+
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
39+
end
40+
41+
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
42+
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
43+
44+
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
45+
-> let=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end

0 commit comments

Comments
 (0)