Skip to content

Commit 5b13ffb

Browse files
authored
improve inference in LinearAlgebra/symmetric (#371)
* improve inference in LinearAlgebra/symmetric * fix eigen rules and adapt check_inferred * add comment, bump patch version
1 parent adcaa91 commit 5b13ffb

File tree

4 files changed

+54
-29
lines changed

4 files changed

+54
-29
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.51"
3+
version = "0.7.52"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{
307307
hermA = Hermitian(A)
308308
∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV)
309309
∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V)
310-
∂Atriu = _symherm_back(typeof(hermA), ∂hermA, hermA.uplo)
310+
∂Atriu = _symherm_back(typeof(hermA), ∂hermA, Symbol(hermA.uplo))
311311
∂A = ∂Atriu isa AbstractTriangular ? triu!(∂Atriu.data) : ∂Atriu
312312
elseif ΔV isa AbstractZero
313313
∂K = Diagonal(Δλ)

src/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ end
88

99
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
1010
Ω = T(A, uplo)
11-
function HermOrSym_pullback(ΔΩ)
12-
return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, Ω.uplo), DoesNotExist())
11+
@inline function HermOrSym_pullback(ΔΩ)
12+
return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), DoesNotExist())
1313
end
1414
return Ω, HermOrSym_pullback
1515
end
@@ -26,7 +26,7 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
2626
TA = _symhermtype(A)
2727
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
2828
uplo = A.uplo
29-
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
29+
∂A = T∂A(_symherm_back(typeof(A), ΔΩ, Symbol(uplo)), uplo)
3030
return NO_FIELDS, ∂A
3131
end
3232
return TM(A), Matrix_pullback
@@ -44,33 +44,46 @@ function _symherm_forward(A, ΔA)
4444
end
4545

4646
# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
47-
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
48-
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
49-
return _symmetric_back(ΔΩ, uplo)
47+
@inline function _symherm_back(::Type{T}, ΔΩ, uplo::Symbol) where {T}
48+
if T <: Symmetric
49+
return _symmetric_back(ΔΩ, uplo)
50+
elseif T <: Hermitian
51+
if ΔΩ isa AbstractMatrix{<:Real}
52+
return _symmetric_back(ΔΩ, uplo)
53+
else
54+
return _hermitian_back(ΔΩ, uplo)
55+
end
56+
end
57+
error()
5058
end
51-
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
52-
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)
5359

54-
function _symmetric_back(ΔΩ, uplo)
60+
@inline function _symmetric_back(ΔΩ, uplo::Symbol)
61+
if ΔΩ isa Diagonal
62+
return ΔΩ
63+
elseif ΔΩ isa LinearAlgebra.AbstractTriangular
64+
if istriu(ΔΩ)
65+
return Matrix(uplo === :U ? ΔΩ : transpose(ΔΩ))
66+
else
67+
return Matrix(uplo === :U ? transpose(ΔΩ) : ΔΩ)
68+
end
69+
end
5570
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
56-
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
71+
return uplo === :U ? U .+ transpose(L) - D : L .+ transpose(U) - D
5772
end
58-
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
59-
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
60-
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)
6173

62-
function _hermitian_back(ΔΩ, uplo)
63-
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
64-
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
65-
end
66-
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
67-
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
68-
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
69-
return if istriu(ΔΩ)
70-
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
71-
else
72-
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
74+
@inline function _hermitian_back(ΔΩ, uplo::Symbol)
75+
if ΔΩ isa Diagonal
76+
return real.(ΔΩ)
77+
elseif ΔΩ isa LinearAlgebra.AbstractTriangular
78+
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
79+
if istriu(ΔΩ)
80+
return Matrix(uplo === :U ? ∂UL : ∂UL')
81+
else
82+
return Matrix(uplo === :U ? ∂UL' : ∂UL)
83+
end
7384
end
85+
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
86+
return uplo === :U ? U .+ L' - rD : L .+ U' - rD
7487
end
7588

7689
#####

test/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,34 @@
1818
@testset "rrule" begin
1919
# on old versions of julia this combination doesn't infer but we don't care as
2020
# it infers fine on modern versions.
21-
check_inferred = !(VERSION <= v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)
21+
check_inferred = !(VERSION < v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)
2222

2323
x = randn(T, N, N)
2424
∂x = randn(T, N, N)
2525
ΔΩ = randn(T, N, N)
2626
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
2727
rrule_test(
2828
SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing);
29-
check_inferred = check_inferred
29+
# type stability here critically relies on uplo being constant propagated,
30+
# so we need to test this more carefully below
31+
check_inferred=false,
3032
)
33+
if check_inferred
34+
@inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo}
35+
return rrule(SymHerm, x, uplo)[2](ΔΩ)
36+
end)(SymHerm, x, MT(ΔΩ), Val(uplo))
37+
end
3138
end
3239
@testset "back(::Diagonal)" begin
3340
rrule_test(
3441
SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing);
35-
check_inferred = check_inferred
42+
check_inferred=false,
3643
)
44+
if check_inferred
45+
@inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo}
46+
return rrule(SymHerm, x, uplo)[2](ΔΩ)
47+
end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo))
48+
end
3749
end
3850
end
3951
end

0 commit comments

Comments
 (0)