diff --git a/src/lib/array.jl b/src/lib/array.jl index 4b8f90609..0d1b59757 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -432,6 +432,12 @@ end @adjoint LinearAlgebra.UnitLowerTriangular(A) = UnitLowerTriangular(A), Δ->(UnitLowerTriangular(Δ)-I,) @adjoint LinearAlgebra.UnitUpperTriangular(A) = UnitUpperTriangular(A), Δ->(UnitUpperTriangular(Δ)-I,) +@adjoint literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),) +@adjoint literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),) +@adjoint literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),) +@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end, 1:end-1]), diag(p̄), diag(p̄[1:end-1, 2:end])) + + # This is basically a hack while we don't have a working `ldiv!`. @adjoint function \(A::Cholesky, B::AbstractVecOrMat) Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B)