1
+ using Base. Broadcast
2
+ using Base. Broadcast: broadcasted
3
+
4
+ # Broadcast over one element is just map
5
+ function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
6
+ ∂⃖ₙ (map, f, a)
7
+ end
8
+
9
+ # The below is from Zygote: TODO : DO we want to do something better here?
10
+
11
+ accum_sum (xs:: Nothing ; dims = :) = NoTangent ()
12
+ accum_sum (xs:: AbstractArray{Nothing} ; dims = :) = NoTangent ()
13
+ accum_sum (xs:: AbstractArray{<:Number} ; dims = :) = sum (xs, dims = dims)
14
+ accum_sum (xs:: AbstractArray{<:AbstractArray{<:Number}} ; dims = :) = sum (xs, dims = dims)
15
+ accum_sum (xs:: Number ; dims = :) = xs
16
+
17
+ # https://github.com/FluxML/Zygote.jl/issues/594
18
+ function Base. reducedim_init (:: typeof (identity), :: typeof (accum), A:: AbstractArray , region)
19
+ Base. reducedim_initarray (A, region, NoTangent (), Union{Nothing,eltype (A)})
20
+ end
21
+
22
+ trim (x, Δ) = reshape (Δ, ntuple (i -> size (Δ, i), Val (ndims (x))))
23
+
24
+ unbroadcast (x:: AbstractArray , x̄) =
25
+ size (x) == size (x̄) ? x̄ :
26
+ length (x) == length (x̄) ? trim (x, x̄) :
27
+ trim (x, accum_sum (x̄, dims = ntuple (i -> size (x, i) == 1 ? i : ndims (x̄)+ 1 , Val (ndims (x̄)))))
28
+
29
+ unbroadcast (x:: Number , x̄) = accum_sum (x̄)
30
+ unbroadcast (x:: Tuple{<:Any} , x̄) = (accum_sum (x̄),)
31
+ unbroadcast (x:: Base.RefValue , x̄) = (x= accum_sum (x̄),)
32
+
33
+ unbroadcast (x:: AbstractArray , x̄:: Nothing ) = NoTangent ()
34
+
35
+ const Numeric = Union{Number, AbstractArray{<: Number , N} where N}
36
+
37
+ function ChainRulesCore. rrule (:: typeof (broadcasted), :: typeof (+ ), xs:: Numeric... )
38
+ broadcast (+ , xs... ), ȳ -> (NoTangent (), NoTangent (), map (x -> unbroadcast (x, unthunk (ȳ)), xs)... )
39
+ end
40
+
41
+ ChainRulesCore. rrule (:: typeof (broadcasted), :: typeof (- ), x:: Numeric , y:: Numeric ) = x .- y,
42
+ Δ -> let Δ= unthunk (Δ); (NoTangent (), NoTangent (), unbroadcast (x, Δ), - unbroadcast (y, Δ)); end
43
+
44
+ ChainRulesCore. rrule (:: typeof (broadcasted), :: typeof (* ), x:: Numeric , y:: Numeric ) = x.* y,
45
+ z̄ -> let z̄= unthunk (z̄); (NoTangent (), NoTangent (), unbroadcast (x, z̄ .* conj .(y)), unbroadcast (y, z̄ .* conj .(x))); end
0 commit comments