8
8
9
9
function rrule (T:: Type{<:LinearAlgebra.HermOrSym} , A:: AbstractMatrix , uplo)
10
10
Ω = T (A, uplo)
11
- function HermOrSym_pullback (ΔΩ)
12
- return (NO_FIELDS, _symherm_back (typeof (Ω), ΔΩ, Ω . uplo), DoesNotExist ())
11
+ @inline function HermOrSym_pullback (ΔΩ)
12
+ return (NO_FIELDS, _symherm_back (typeof (Ω), ΔΩ, uplo), DoesNotExist ())
13
13
end
14
14
return Ω, HermOrSym_pullback
15
15
end
@@ -26,7 +26,7 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
26
26
TA = _symhermtype (A)
27
27
T∂A = TA{eltype (ΔΩ),typeof (ΔΩ)}
28
28
uplo = A. uplo
29
- ∂A = T∂A (_symherm_back (A , ΔΩ, uplo), uplo)
29
+ ∂A = T∂A (_symherm_back (typeof (A) , ΔΩ, Symbol ( uplo) ), uplo)
30
30
return NO_FIELDS, ∂A
31
31
end
32
32
return TM (A), Matrix_pullback
@@ -44,33 +44,46 @@ function _symherm_forward(A, ΔA)
44
44
end
45
45
46
46
# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
47
- _symherm_back (:: Type{<:Symmetric} , ΔΩ, uplo) = _symmetric_back (ΔΩ, uplo)
48
- function _symherm_back (:: Type{<:Hermitian} , ΔΩ:: AbstractMatrix{<:Real} , uplo)
49
- return _symmetric_back (ΔΩ, uplo)
47
+ @inline function _symherm_back (:: Type{T} , ΔΩ, uplo:: Symbol ) where {T}
48
+ if T <: Symmetric
49
+ return _symmetric_back (ΔΩ, uplo)
50
+ elseif T <: Hermitian
51
+ if ΔΩ isa AbstractMatrix{<: Real }
52
+ return _symmetric_back (ΔΩ, uplo)
53
+ else
54
+ return _hermitian_back (ΔΩ, uplo)
55
+ end
56
+ end
57
+ error ()
50
58
end
51
- _symherm_back (:: Type{<:Hermitian} , ΔΩ, uplo) = _hermitian_back (ΔΩ, uplo)
52
- _symherm_back (Ω, ΔΩ, uplo) = _symherm_back (typeof (Ω), ΔΩ, uplo)
53
59
54
- function _symmetric_back (ΔΩ, uplo)
60
+ @inline function _symmetric_back (ΔΩ, uplo:: Symbol )
61
+ if ΔΩ isa Diagonal
62
+ return ΔΩ
63
+ elseif ΔΩ isa LinearAlgebra. AbstractTriangular
64
+ if istriu (ΔΩ)
65
+ return Matrix (uplo === :U ? ΔΩ : transpose (ΔΩ))
66
+ else
67
+ return Matrix (uplo === :U ? transpose (ΔΩ) : ΔΩ)
68
+ end
69
+ end
55
70
L, U, D = LowerTriangular (ΔΩ), UpperTriangular (ΔΩ), Diagonal (ΔΩ)
56
- return uplo == ' U ' ? U .+ transpose (L) - D : L .+ transpose (U) - D
71
+ return uplo === :U ? U .+ transpose (L) - D : L .+ transpose (U) - D
57
72
end
58
- _symmetric_back (ΔΩ:: Diagonal , uplo) = ΔΩ
59
- _symmetric_back (ΔΩ:: UpperTriangular , uplo) = Matrix (uplo == ' U' ? ΔΩ : transpose (ΔΩ))
60
- _symmetric_back (ΔΩ:: LowerTriangular , uplo) = Matrix (uplo == ' U' ? transpose (ΔΩ) : ΔΩ)
61
73
62
- function _hermitian_back (ΔΩ, uplo)
63
- L, U, rD = LowerTriangular (ΔΩ), UpperTriangular (ΔΩ), real .(Diagonal (ΔΩ))
64
- return uplo == ' U' ? U .+ L' - rD : L .+ U' - rD
65
- end
66
- _hermitian_back (ΔΩ:: Diagonal , uplo) = real .(ΔΩ)
67
- function _hermitian_back (ΔΩ:: LinearAlgebra.AbstractTriangular , uplo)
68
- ∂UL = ΔΩ .- Diagonal (_extract_imag .(diag (ΔΩ)))
69
- return if istriu (ΔΩ)
70
- return Matrix (uplo == ' U' ? ∂UL : ∂UL' )
71
- else
72
- return Matrix (uplo == ' U' ? ∂UL' : ∂UL)
74
+ @inline function _hermitian_back (ΔΩ, uplo:: Symbol )
75
+ if ΔΩ isa Diagonal
76
+ return real .(ΔΩ)
77
+ elseif ΔΩ isa LinearAlgebra. AbstractTriangular
78
+ ∂UL = ΔΩ .- Diagonal (_extract_imag .(diag (ΔΩ)))
79
+ if istriu (ΔΩ)
80
+ return Matrix (uplo === :U ? ∂UL : ∂UL' )
81
+ else
82
+ return Matrix (uplo === :U ? ∂UL' : ∂UL)
83
+ end
73
84
end
85
+ L, U, rD = LowerTriangular (ΔΩ), UpperTriangular (ΔΩ), real .(Diagonal (ΔΩ))
86
+ return uplo === :U ? U .+ L' - rD : L .+ U' - rD
74
87
end
75
88
76
89
# ####
0 commit comments