Skip to content

Commit ff7c7eb

Browse files
tgymnich and maleadt authored
Use unified memory for scalar indexing of permutation matrices (#313)
Co-authored-by: Tim Besard <[email protected]>
1 parent c8cf84a commit ff7c7eb

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

lib/mps/linalg.jl

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -109,8 +109,10 @@ end
109109

110110
# Metal's pivoting sequence needs to be iterated sequentially...
111111
# TODO: figure out a GPU-compatible way to get the permutation matrix
112-
LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
112+
LinearAlgebra.ipiv2perm(v::MtlVector, maxi::Integer) =
113113
LinearAlgebra.ipiv2perm(Array(v), maxi)
114+
LinearAlgebra.ipiv2perm(v::MtlVector{<:Any,MTL.CPUStorage}, maxi::Integer) =
115+
LinearAlgebra.ipiv2perm(unsafe_wrap(Array, v), maxi)
114116

115117
@autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T};
116118
check::Bool=true) where {T<:MtlFloat}
@@ -129,7 +131,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
129131
end
130132

131133
P = similar(A, UInt32, 1, min(N, M))
132-
status = MtlArray{MPSMatrixDecompositionStatus}(undef)
134+
status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef)
133135

134136
commitAndContinue!(cmdbuf) do cbuf
135137
mps_p = MPSMatrix(P)
@@ -150,7 +152,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
150152

151153
wait_completed(cmdbuf)
152154

153-
status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[])
155+
status = convert(LinearAlgebra.BlasInt, status[])
154156
check && checknonsingular(status)
155157

156158
return LinearAlgebra.LU(B, p, status)
@@ -187,7 +189,7 @@ end
187189
end
188190

189191
P = similar(A, UInt32, 1, min(N, M))
190-
status = MtlArray{MPSMatrixDecompositionStatus}(undef)
192+
status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef)
191193

192194
commitAndContinue!(cmdbuf) do cbuf
193195
mps_p = MPSMatrix(P)
@@ -205,7 +207,7 @@ end
205207

206208
wait_completed(cmdbuf)
207209

208-
status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[])
210+
status = convert(LinearAlgebra.BlasInt, status[])
209211
check && _check_lu_success(status, allowsingular)
210212

211213
return LinearAlgebra.LU(A, p, status)

0 commit comments

Comments
 (0)