Skip to content

Commit 9267b79

Browse files
authored
Merge pull request #24 from devmotion/trackedvecormat
Do not redefine TrackedVecOrMat
2 parents 87aeca7 + 6db9168 commit 9267b79

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

src/common.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,14 @@ function LinearAlgebra.logdet(C::Cholesky{<:Tracker.TrackedReal, <:Tracker.Track
8989
end
9090

9191
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
92-
const TrackedVecOrMat = Union{Tracker.TrackedVector, Tracker.TrackedMatrix}
9392
zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
94-
function zygote_ldiv(A::Tracker.TrackedMatrix, B::TrackedVecOrMat)
93+
function zygote_ldiv(A::Tracker.TrackedMatrix, B::Tracker.TrackedVecOrMat)
9594
return Tracker.track(zygote_ldiv, A, B)
9695
end
9796
function zygote_ldiv(A::Tracker.TrackedMatrix, B::AbstractVecOrMat)
9897
return Tracker.track(zygote_ldiv, A, B)
9998
end
100-
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = Tracker.track(zygote_ldiv, A, B)
99+
zygote_ldiv(A::AbstractMatrix, B::Tracker.TrackedVecOrMat) = Tracker.track(zygote_ldiv, A, B)
101100
Tracker.@grad function zygote_ldiv(A, B)
102101
Y, back = Zygote.pullback(\, Tracker.data(A), Tracker.data(B))
103102
return Y, Δ->back(Tracker.data(Δ))

0 commit comments

Comments
 (0)