40
40
41
41
# # Linear algebra ##
42
42
43
- LinearAlgebra. UpperTriangular (A:: TrackedMatrix ) = track (UpperTriangular, A)
44
- @grad function LinearAlgebra. UpperTriangular (A:: AbstractMatrix )
45
- return UpperTriangular (data (A)), Δ-> (UpperTriangular (Δ),)
43
+ # Work around https://github.com/FluxML/Tracker.jl/pull/9#issuecomment-480051767
44
+
45
+ upper (A:: AbstractMatrix ) = UpperTriangular (A)
46
+ lower (A:: AbstractMatrix ) = LowerTriangular (A)
47
+ function upper (C:: Cholesky )
48
+ if C. uplo == ' U'
49
+ return upper (C. factors)
50
+ else
51
+ return copy (lower (C. factors)' )
52
+ end
53
+ end
54
+ function lower (C:: Cholesky )
55
+ if C. uplo == ' U'
56
+ return copy (upper (C. factors)' )
57
+ else
58
+ return lower (C. factors)
59
+ end
60
+ end
61
+
62
+ LinearAlgebra. LowerTriangular (A:: TrackedMatrix ) = lower (A)
63
+ lower (A:: TrackedMatrix ) = track (lower, A)
64
+ @grad lower (A) = lower (Tracker. data (A)), ∇ -> (lower (∇),)
65
+
66
+ LinearAlgebra. UpperTriangular (A:: TrackedMatrix ) = upper (A)
67
+ upper (A:: TrackedMatrix ) = track (upper, A)
68
+ @grad upper (A) = upper (Tracker. data (A)), ∇ -> (upper (∇),)
69
+
70
+ function Base. copy (
71
+ A:: TrackedArray {T, 2 , <: Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}} },
72
+ ) where {T <: Real }
73
+ return track (copy, A)
74
+ end
75
+ @grad function Base. copy (
76
+ A:: TrackedArray {T, 2 , <: Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}} },
77
+ ) where {T <: Real }
78
+ return copy (data (A)), ∇ -> (copy (∇),)
46
79
end
47
80
48
81
function LinearAlgebra. cholesky (A:: TrackedMatrix ; check= true )
@@ -57,40 +90,10 @@ function turing_chol(A::AbstractMatrix, check)
57
90
end
58
91
turing_chol (A:: TrackedMatrix , check) = track (turing_chol, A, check)
59
92
@grad function turing_chol (A:: AbstractMatrix , check)
60
- C, back = pullback (unsafe_cholesky , data (A), data (check))
93
+ C, back = pullback (_turing_chol , data (A), data (check))
61
94
return (C. factors, C. info), Δ-> back ((factors= data (Δ[1 ]),))
62
95
end
63
-
64
- unsafe_cholesky (x, check) = cholesky (x, check= check)
65
- @adjoint function unsafe_cholesky (Σ:: Real , check)
66
- C = cholesky (Σ; check= check)
67
- return C, function (Δ:: NamedTuple )
68
- issuccess (C) || return (zero (Σ), nothing )
69
- (Δ. factors[1 , 1 ] / (2 * C. U[1 , 1 ]), nothing )
70
- end
71
- end
72
- @adjoint function unsafe_cholesky (Σ:: Diagonal , check)
73
- C = cholesky (Σ; check= check)
74
- return C, function (Δ:: NamedTuple )
75
- issuccess (C) || (Diagonal (zero (diag (Δ. factors))), nothing )
76
- (Diagonal (diag (Δ. factors) .* inv .(2 .* C. factors. diag)), nothing )
77
- end
78
- end
79
- @adjoint function unsafe_cholesky (Σ:: Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}} , check)
80
- C = cholesky (Σ; check= check)
81
- return C, function (Δ:: NamedTuple )
82
- issuccess (C) || return (zero (Δ. factors), nothing )
83
- U, Ū = C. U, Δ. factors
84
- Σ̄ = Ū * U'
85
- Σ̄ = copytri! (Σ̄, ' U' )
86
- Σ̄ = ldiv! (U, Σ̄)
87
- BLAS. trsm! (' R' , ' U' , ' T' , ' N' , one (eltype (Σ)), U. data, Σ̄)
88
- @inbounds for n in diagind (Σ̄)
89
- Σ̄[n] /= 2
90
- end
91
- return (UpperTriangular (Σ̄), nothing )
92
- end
93
- end
96
+ _turing_chol (x, check) = cholesky (x, check= check)
94
97
95
98
# Specialised logdet for cholesky to target the triangle directly.
96
99
logdet_chol_tri (U:: AbstractMatrix ) = 2 * sum (log, U[diagind (U)])
0 commit comments