Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.15.6"
version = "1.16.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
6 changes: 4 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,11 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
# Apply `muladd` iteratively.
# Explicit multiplication is only performed for the first pair of partial and gradient.
(∂s_1, Δs_1), _∂s_Δs_tail = Iterators.peel(zip(_∂s, Δs))
init_expr = :($∂s_1 * $Δs_1)
# Zero (co)tangents act as hard zeros: when `Δ` is zero its partial is replaced by
# `zero(∂)`, so a non-finite partial cannot propagate NaN into the result.
init_expr = :((iszero($Δs_1) ? zero($∂s_1) : $∂s_1) * $Δs_1)
summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i)
:(muladd($∂s_i, $Δs_i, $ex))
:(muladd((iszero($Δs_i) ? zero($∂s_i) : $∂s_i), $Δs_i, $ex))
end
return :($proj($summed_∂_mul_Δs))
end
Expand Down
3 changes: 3 additions & 0 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,6 @@ arguments.
```
"""
struct NoTangent <: AbstractZero end

# `zero` of a `NoTangent` is `NoTangent()`, whether it is called on an instance or on
# the type itself. `NoTangent` has no fields, so returning the instance is equivalent.
Base.zero(dne::NoTangent) = dne
Base.zero(::Type{NoTangent}) = NoTangent()
23 changes: 23 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,29 @@ end
@test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0)
end

@testset "@scalar_rule strong zero (co)tangents" begin
    # `suminv` has a non-finite partial wherever one of its arguments is zero. A zero
    # (co)tangent paired with such a partial must contribute exactly zero, not NaN.
    suminv(x, y) = inv(x) + inv(y)
    @scalar_rule suminv(x, y) (-(inv(x)^2), -(inv(y)^2))

    nt = NoTangent()
    zt = ZeroTangent()

    # Forward mode, non-finite partial with respect to `x`.
    @test frule((nt, 1.0, 1.0), suminv, 0.0, 1.0) === (Inf, -Inf)
    @test frule((nt, zt, 1.0), suminv, 0.0, 1.0) === (Inf, -1.0)
    @test frule((nt, 0.0, 1.0), suminv, 0.0, 1.0) === (Inf, -1.0)

    # Forward mode, non-finite partial with respect to `y`.
    @test frule((nt, 1.0, 1.0), suminv, 1.0, 0.0) === (Inf, -Inf)
    @test frule((nt, 1.0, zt), suminv, 1.0, 0.0) === (Inf, -1.0)
    @test frule((nt, 1.0, 0.0), suminv, 1.0, 0.0) === (Inf, -1.0)

    # Reverse mode, non-finite partial with respect to `x`.
    _, pullback_x = rrule(suminv, 0.0, 1.0)
    @test pullback_x(1.0) === (nt, -Inf, -1.0)
    @test pullback_x(zt) === (nt, zt, zt)
    @test pullback_x(0.0) === (nt, 0.0, 0.0)

    # Reverse mode, non-finite partial with respect to `y`.
    _, pullback_y = rrule(suminv, 1.0, 0.0)
    @test pullback_y(1.0) === (nt, -1.0, -Inf)
    @test pullback_y(zt) === (nt, zt, zt)
    @test pullback_y(0.0) === (nt, 0.0, 0.0)
end

@testset "Regression tests against #276 and #265" begin
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265
Expand Down
4 changes: 2 additions & 2 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@
end
@test broadcastable(z) isa Ref{ZeroTangent}
@test zero(@thunk(3)) === z
@test zero(NoTangent()) === z
@test zero(ZeroTangent) === z
@test zero(NoTangent) === z
@test zero(Tangent{Tuple{Int,Int}}((1, 2))) === z
for f in (transpose, adjoint, conj)
@test f(z) === z
Expand Down Expand Up @@ -94,6 +92,8 @@

@testset "NoTangent" begin
dne = NoTangent()
@test zero(dne) === NoTangent()
@test zero(NoTangent) === NoTangent()
@test dne + dne == dne
@test dne + 1 == 1
@test 1 + dne == 1
Expand Down