Skip to content

Commit 2c009d1

Browse files
authored
Fix rrules for Symmetric and Diagonal constructors (#23)
Currently these definitions are extending `rrule(::typeof(T), x)` where `T` is a type. However, `typeof(Diagonal) == UnionAll`, which means this is not defining the method it looks like it might be defining. The only reason this worked when originally implemented was that one of the `rrule` definitions was for `rrule(UnionAll, Matrix)` and the other for `rrule(UnionAll, Vector)`, so dispatch still worked. This replaces these problematic `::typeof(T)`s with `::Type{<:T}`.
1 parent f783dff commit 2c009d1

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

src/rules/linalg/diagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
rrule(::typeof(Diagonal), d::AbstractVector) = Diagonal(d), Rule(diag)
1+
rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)
22
rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)

src/rules/linalg/symmetric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
rrule(::typeof(Symmetric), A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)
1+
rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)
22

33
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
44
_symmetric_back(ΔΩ::Union{Diagonal, UpperTriangular}) = ΔΩ

test/rules/linalg/diagonal.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
@testset "Diagonal" begin
33
rng, N = MersenneTwister(123456), 3
44
rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
5-
rrule_test(Diagonal, Diagonal(randn(rng, N)), (randn(rng, N), randn(rng, N)))
5+
D = Diagonal(randn(rng, N))
6+
rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N)))
7+
# Concrete type instead of UnionAll
8+
rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N)))
69
end
710
@testset "diag" begin
811
rng, N = MersenneTwister(123456), 7

0 commit comments

Comments
 (0)