
Commit 6162295

remove reverse mode broadcasting rules
1 parent: d50f6e8

File tree: 3 files changed, 47 insertions(+), 49 deletions(-)


src/extra_rules.jl

Lines changed: 0 additions & 6 deletions
@@ -150,12 +150,6 @@ end
 
 ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
 
-# Skip AD'ing through the axis computation
-function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
-    return Base.Broadcast.instantiate(bc), Δ->begin
-        Core.tuple(NoTangent(), Δ)
-    end
-end
 
 
 using StaticArrays
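For context, `Base.Broadcast.instantiate` only computes and attaches the result axes of a lazy `Broadcasted` container, which is why the rule removed above treated it as a pass-through with no tangent of its own. A minimal sketch of that behavior using only the standard `Base.Broadcast` API:

# instantiate does pure shape bookkeeping: it fills in the axes field.
bc  = Base.Broadcast.broadcasted(+, [1, 2, 3], 4)  # lazy Broadcasted; axes not yet computed
@assert bc.axes === nothing
ibc = Base.Broadcast.instantiate(bc)               # computes and attaches the result axes
@assert ibc.axes == (Base.OneTo(3),)
@assert copy(ibc) == [5, 6, 7]                     # materializing the result is unaffected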

src/stage1/broadcast.jl

Lines changed: 0 additions & 43 deletions
@@ -28,46 +28,3 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
     end
     return r
 end
-
-# Broadcast over one element is just map
-function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
-    ∂⃖ₙ(map, f, a)
-end
-
-# The below is from Zygote. TODO: Do we want to do something better here?
-
-accum_sum(xs::Nothing; dims = :) = NoTangent()
-accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
-accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
-accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
-accum_sum(xs::Number; dims = :) = xs
-
-# https://github.com/FluxML/Zygote.jl/issues/594
-function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
-    Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
-end
-
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-
-unbroadcast(x::AbstractArray, x̄) =
-    size(x) == size(x̄) ? x̄ :
-    length(x) == length(x̄) ? trim(x, x̄) :
-    trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
-
-unbroadcast(x::Number, x̄) = accum_sum(x̄)
-unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
-unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-
-unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
-
-const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
-
-function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::Numeric...)
-    broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
-end
-
-ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
-    Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
-
-ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
-    z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
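The first deleted method relied on the identity that broadcasting a function over a single array argument is just `map`, since no shape extrusion can occur with one container. A minimal sketch of that identity in plain Julia:

# With exactly one array argument, broadcast and map agree elementwise.
f(x) = x^2
a = [1, 2, 3]
@assert broadcast(f, a) == map(f, a)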

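The `unbroadcast`/`accum_sum` helpers deleted above reduce a cotangent carrying the broadcast's (larger) shape back to the shape of one argument by summing over every extruded dimension. A hand-unrolled sketch of the general branch, for a length-3 vector broadcast against a 3×2 array (the values here are illustrative, not taken from the test suite):

x  = [1.0, 2.0, 3.0]      # argument of size (3,)
x̄  = ones(3, 2)           # cotangent with the broadcast's shape (3, 2)

# dims picks each dimension where x was extruded (size 1) and maps the
# others to a harmless out-of-range dimension; here dims == (3, 2).
dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄) + 1, Val(ndims(x̄)))
summed = sum(x̄; dims = dims)                                         # 3×1 array
grad = reshape(summed, ntuple(i -> size(summed, i), Val(ndims(x))))  # trim back to size (3,)
@assert grad == [2.0, 2.0, 2.0]

The fast-path rrules for broadcasted `+`, `-`, and `*` then only need one `unbroadcast` call per argument; the `conj.` in the `*` rule keeps the pullback correct for complex operands.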
test/runtests.jl

Lines changed: 47 additions & 0 deletions
@@ -219,6 +219,53 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
 @test z45 ≈ 2.0
 @test delta45 ≈ 1.0
 
+@testset "broadcast" begin
+    @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],)  # derivatives_given_output
+    @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3
+    @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
+
+    @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],)  # frule_via_ad
+    exp_log(x) = exp(log(x))
+    @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
+    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
+    @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
+    @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,)  # closure
+
+    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3  # array of arrays
+    @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
+    @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3
+    @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12]  # must not take the * fast path
+
+    @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],)
+    @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,)  # x/y rule
+    @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],)  # x.^2 rule
+    @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,)  # scalar^2 rule
+
+    @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
+    @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
+    @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)
+
+    @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero  # Bool output
+    @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero
+    @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent())
+    @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero  # Bool input
+    @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero
+    @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero
+
+    tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5]))
+    @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
+    @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
+    @test tup_adj[2] isa Transpose
+    @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
+end
+
+@testset "broadcast, 2nd order" begin
+    @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3)  # MethodError: no method matching copy(::Nothing)
+    @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3.0])[1] ≈ exp.(1:3)
+    @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],)  # ERROR: (1, current_logger_for_env(std_level::Base.CoreLogging.LogLevel, group, _module) @ Base.CoreLogging logging.jl:500, :($(Expr(:meta, :noinline))))
+    @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]
+end
+
 # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
 #include("pinn.jl")
