Skip to content

Commit 762a5e5

Browse files
committed
make work for new ChainRules v0.1
1 parent b54a3fc commit 762a5e5

File tree

3 files changed

+2
-76
lines changed

3 files changed

+2
-76
lines changed

src/AbstractChainRules.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module AbstractChainRules
2-
32
using Cassette
43
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
54

65
export AbstractRule, Rule, frule, rrule
6+
export @scalar_rule, @thunk
7+
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
78

89
include("differentials.jl")
910
include("rules.jl")

src/rules.jl

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -112,32 +112,6 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
112112
"""
113113
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))
114114

115-
# Special purpose updating for operations which can be done in-place. This function is
116-
# just internal and free-form; it is not a method of `accumulate!` directly as it does
117-
# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`.
118-
# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular
119-
# rule.
120-
121-
_update!(x, y) = x + y
122-
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y
123-
124-
_update!(x, ::Zero) = x
125-
_update!(::Zero, y) = y
126-
_update!(::Zero, ::Zero) = Zero()
127-
128-
function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
129-
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
130-
end
131-
132-
function _update!(x::NamedTuple, y, p::Symbol)
133-
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
134-
return merge(x, new)
135-
end
136-
137-
function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
138-
return _update!(x, getproperty(y, p), p)
139-
end
140-
141115
#####
142116
##### `Rule`
143117
#####
@@ -377,23 +351,6 @@ See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
377351
"""
378352
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
379353

380-
@noinline function _throw_checked_rrule_error(f, args...; kwargs...)
381-
io = IOBuffer()
382-
print(io, "can't differentiate `", f, '(')
383-
join(io, map(arg->string("::", typeof(arg)), args), ", ")
384-
if !isempty(kwargs)
385-
print(io, ";")
386-
join(io, map(((k, v),)->string(k, "=", v), kwargs), ", ")
387-
end
388-
print(io, ")`; no matching `rrule` is defined")
389-
throw(ArgumentError(String(take!(io))))
390-
end
391-
392-
function _checked_rrule(f, args...; kwargs...)
393-
r = rrule(f, args...; kwargs...)
394-
r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...)
395-
return r
396-
end
397354

398355
#####
399356
##### macros

test/rules.jl

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,36 +42,4 @@ dummy_identity(x) = x
4242
@test rule[1] == rule
4343
@test_throws BoundsError rule[2]
4444
end
45-
@testset "helper functions" begin
46-
# Hits fallback, since we can't update `Diagonal`s in place
47-
X = Diagonal([1, 1])
48-
Y = copy(X)
49-
@test AbstractChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5]
50-
@test X == Y # no change to X
51-
52-
X = [1 2; 3 4]
53-
Y = copy(X)
54-
@test AbstractChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5]
55-
@test X != Y # X has been updated
56-
57-
# Reusing above X
58-
@test AbstractChainRules._update!(X, Zero()) === X
59-
@test AbstractChainRules._update!(Zero(), X) === X
60-
@test AbstractChainRules._update!(Zero(), Zero()) === Zero()
61-
62-
X = (A=[1 0; 0 1], B=[2 2; 2 2])
63-
Y = deepcopy(X)
64-
@test AbstractChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4])
65-
@test X.A != Y.A
66-
@test X.B != Y.B
67-
68-
try
69-
# We defined a 2-arg method for `cool` but no `rrule`
70-
AbstractChainRules._checked_rrule(cool, 1.0, 2.0)
71-
catch e
72-
@test e isa ArgumentError
73-
@test e.msg == "can't differentiate `cool(::Float64, ::Float64)`; no " *
74-
"matching `rrule` is defined"
75-
end
76-
end
7745
end

0 commit comments

Comments
 (0)