Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.5.0"
version = "0.5.1"

[compat]
julia = "^1.0"
Expand Down
8 changes: 4 additions & 4 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ Base.:*(::DoesNotExist, ::Zero) = Zero()
Base.:*(::Zero, ::DoesNotExist) = Zero()


Base.:+(::Zero, b::Zero) = Zero()
Base.:+(::Zero, ::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:One, :AbstractThunk, :Any)
@eval Base.:+(::Zero, b::$T) = b
@eval Base.:+(a::$T, ::Zero) = a

@eval Base.:*(::Zero, ::$T) = Zero()
@eval Base.:*(::$T, ::Zero) = Zero()
@eval Base.:*(::Zero, x::$T) = zero(x)
@eval Base.:*(x::$T, ::Zero) = zero(x)
end

Base.zero(::AbstractDifferential) = Zero()

Base.:+(a::One, b::One) = extern(a) + extern(b)
Base.:*(::One, ::One) = One()
Expand Down
14 changes: 7 additions & 7 deletions test/differentials/zero.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
@testset "Zero" begin
z = Zero()
@test extern(z) === false
@test z + z == z
@test z + 1 == 1
@test 1 + z == 1
@test z * z == z
@test z * 1 == z
@test 1 * z == z
@test z + z === z
@test z + 1 === 1
@test 1 + z === 1
@test z * z === z
@test z * 1 === 0
@test 1 * z === 0
for x in z
@test x === z
end
@test broadcastable(z) isa Ref{Zero}
@test conj(z) == z
@test conj(z) === z
end
18 changes: 13 additions & 5 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ cool(x, y) = x + y + 1
dummy_identity(x) = x
@scalar_rule(dummy_identity(x), One())

nice(x) = 1
@scalar_rule(nice(x), Zero())

#######

_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
Expand All @@ -31,11 +34,16 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test cool_methods == only_methods

frx, cool_pushforward = frule(cool, 1, dself, 1)
@test frx == 2
@test cool_pushforward == 1
@test frx === 2
@test cool_pushforward === 1
rrx, cool_pullback = rrule(cool, 1)
self, rr1 = cool_pullback(1)
@test self == NO_FIELDS
@test rrx == 2
@test rr1 == 1
@test self === NO_FIELDS
@test rrx === 2
@test rr1 === 1

frx, nice_pushforward = frule(nice, 1, dself, 1)
@test nice_pushforward === 0
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, 0) === nice_pullback(1)
end