Skip to content

Commit be9386c

Browse files
authored
Fix ReverseDiff bug (#278)
* Fix ReverseDiff bug * Add tests for MWE * Add tests with +
1 parent 3987311 commit be9386c

File tree

3 files changed

+26
-24
lines changed

3 files changed

+26
-24
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "1.4.0"
3+
version = "1.4.1"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -15,8 +15,9 @@ julia = "1.6"
1515
[extras]
1616
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1717
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
18+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1819
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1920
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021

2122
[targets]
22-
test = ["Aqua", "Test", "Base64", "StaticArrays"]
23+
test = ["Aqua", "Test", "Base64", "ReverseDiff", "StaticArrays"]

src/fillbroadcast.jl

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,18 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::O
195195
return _range_convert(AbstractVector{TT}, a)
196196
end
197197

198-
# To fix AD issues with `broadcast(T, x)`
199-
# Avoids type inference issues with x -> T(x)
200-
struct Constructor{T} end
201-
202-
function (::Constructor{T})(x) where {T}
203-
return T(x)
204-
end
205-
206198
for op in (:+, :-)
207199
@eval begin
208200
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector, b::ZerosVector)
209201
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
210202
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
211-
eltype(a) === TT ? a : broadcasted(Constructor{TT}(), a)
203+
# Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)`
204+
eltype(a) === TT ? a : broadcasted(TT (+), a)
205+
end
206+
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::ZerosVector, b::AbstractVector)
207+
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
208+
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
209+
$op === (+) && eltype(b) === TT ? b : broadcasted(TT ($op), b)
212210
end
213211

214212
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector, b::ZerosVector) =
@@ -219,18 +217,6 @@ for op in (:+, :-)
219217
end
220218
end
221219

222-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::ZerosVector, b::AbstractVector)
223-
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
224-
TT = typeof(zero(eltype(a)) + zero(eltype(b)))
225-
eltype(b) === TT ? b : broadcasted(Constructor{TT}(), b)
226-
end
227-
228-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::ZerosVector, b::AbstractVector)
229-
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
230-
TT = typeof(zero(eltype(a)) - zero(eltype(b)))
231-
broadcasted(TT (-), b)
232-
end
233-
234220
# Need to prevent array-valued fills from broadcasting over entry
235221
_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a)
236222
_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a))

test/runtests.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FillArrays, LinearAlgebra, SparseArrays, StaticArrays, Random, Base64, Test, Statistics
1+
using FillArrays, LinearAlgebra, SparseArrays, StaticArrays, ReverseDiff, Random, Base64, Test, Statistics
22
import FillArrays: AbstractFill, RectDiagonal, SquareEye
33

44
using Aqua
@@ -2071,3 +2071,18 @@ end
20712071
end
20722072
end
20732073
end
2074+
2075+
@testset "ReverseDiff with Zeros" begin
2076+
# MWE in https://github.com/JuliaArrays/FillArrays.jl/issues/252
2077+
@test ReverseDiff.gradient(x -> sum(abs2.((Zeros(5) .- zeros(5)) ./ x)), rand(5)) == zeros(5)
2078+
@test ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .- Zeros(5)) ./ x)), rand(5)) == zeros(5)
2079+
# MWE in https://github.com/JuliaArrays/FillArrays.jl/pull/278
2080+
@test ReverseDiff.gradient(x -> sum(abs2.((Zeros{eltype(x)}(5) .- zeros(5)) ./ x)), rand(5)) == zeros(5)
2081+
@test ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .- Zeros{eltype(x)}(5)) ./ x)), rand(5)) == zeros(5)
2082+
2083+
# Corresponding tests with +
2084+
@test ReverseDiff.gradient(x -> sum(abs2.((Zeros(5) .+ zeros(5)) ./ x)), rand(5)) == zeros(5)
2085+
@test ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .+ Zeros(5)) ./ x)), rand(5)) == zeros(5)
2086+
@test ReverseDiff.gradient(x -> sum(abs2.((Zeros{eltype(x)}(5) .+ zeros(5)) ./ x)), rand(5)) == zeros(5)
2087+
@test ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .+ Zeros{eltype(x)}(5)) ./ x)), rand(5)) == zeros(5)
2088+
end

0 commit comments

Comments
 (0)