Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.5.0"
version = "0.5.1"

[compat]
julia = "^1.0"
Expand Down
9 changes: 5 additions & 4 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ end
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
for T in (:Any,)
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)
# we want to eagerly compute the result when thunk meets other types
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)

@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
end

################## Composite ##############################################################
Expand Down
3 changes: 2 additions & 1 deletion src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ function propagation_expr(Δs, ∂s)
# This is basically Δs ⋅ ∂s
∂s = map(esc, ∂s)

∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s))
# this is neccssary since we want to eagerly evaluate the result
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
return :(+($(∂_mul_Δs...)))
end

Expand Down
18 changes: 13 additions & 5 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ cool(x, y) = x + y + 1
dummy_identity(x) = x
@scalar_rule(dummy_identity(x), One())

nice(x) = 1
@scalar_rule(nice(x), Zero())

#######

_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
Expand All @@ -31,11 +34,16 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test cool_methods == only_methods

frx, cool_pushforward = frule(cool, 1, dself, 1)
@test frx == 2
@test cool_pushforward == 1
@test frx === 2
@test cool_pushforward === 1
rrx, cool_pullback = rrule(cool, 1)
self, rr1 = cool_pullback(1)
@test self == NO_FIELDS
@test rrx == 2
@test rr1 == 1
@test self === NO_FIELDS
@test rrx === 2
@test rr1 === 1

frx, nice_pushforward = frule(nice, 1, dself, 1)
@test nice_pushforward === 0
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, 0) === nice_pullback(1)
end