109109
110110# Metal's pivoting sequence needs to be iterated sequentially...
111111# TODO : figure out a GPU-compatible way to get the permutation matrix
112- LinearAlgebra. ipiv2perm (v:: MtlVector{T} , maxi:: Integer ) where T =
112+ LinearAlgebra. ipiv2perm (v:: MtlVector , maxi:: Integer ) =
113113 LinearAlgebra. ipiv2perm (Array (v), maxi)
114+ LinearAlgebra. ipiv2perm (v:: MtlVector{<:Any,MTL.CPUStorage} , maxi:: Integer ) =
115+ LinearAlgebra. ipiv2perm (unsafe_wrap (Array, v), maxi)
114116
115117@autoreleasepool function LinearAlgebra. lu (A:: MtlMatrix{T} ;
116118 check:: Bool = true ) where {T<: MtlFloat }
@@ -129,7 +131,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
129131 end
130132
131133 P = similar (A, UInt32, 1 , min (N, M))
132- status = MtlArray {MPSMatrixDecompositionStatus} (undef)
134+ status = MtlArray {MPSMatrixDecompositionStatus,0,SharedStorage } (undef)
133135
134136 commitAndContinue! (cmdbuf) do cbuf
135137 mps_p = MPSMatrix (P)
@@ -150,7 +152,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
150152
151153 wait_completed (cmdbuf)
152154
153- status = convert (LinearAlgebra. BlasInt, Metal . @allowscalar status[])
155+ status = convert (LinearAlgebra. BlasInt, status[])
154156 check && checknonsingular (status)
155157
156158 return LinearAlgebra. LU (B, p, status)
187189 end
188190
189191 P = similar (A, UInt32, 1 , min (N, M))
190- status = MtlArray {MPSMatrixDecompositionStatus} (undef)
192+ status = MtlArray {MPSMatrixDecompositionStatus,0,SharedStorage } (undef)
191193
192194 commitAndContinue! (cmdbuf) do cbuf
193195 mps_p = MPSMatrix (P)
205207
206208 wait_completed (cmdbuf)
207209
208- status = convert (LinearAlgebra. BlasInt, Metal . @allowscalar status[])
210+ status = convert (LinearAlgebra. BlasInt, status[])
209211 check && _check_lu_success (status, allowsingular)
210212
211213 return LinearAlgebra. LU (A, p, status)
0 commit comments