@@ -192,110 +192,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192
193193 commit!(cmdbuf)
194194
195+ wait_completed(cmdbuf)
196+
195197 return B
196198end
197199
198200
"""
    A \\ B

Solve `A * X = B` for an `LU` factorization `A` whose factors live in an `MtlMatrix`.

The right-hand side `B` is left untouched; the solution is returned in a newly allocated
array of the same type and shape as `B`.
"""
function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}},
                            B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `ldiv!` overwrites its second argument, so work on a copy. A plain `copy` is enough
    # for a flat array of numbers; `deepcopy` is not needed here.
    X = copy(B)
    LinearAlgebra.ldiv!(A, X)
    return X
end
206+
207+
"""
    ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T})

Solve `A * X = B` in place with `MPSMatrixSolveLU`, overwriting `B` with the solution.

# Returns
`B`, now holding the solution.

# Throws
- `DimensionMismatch` if the LU factors are not square or their order differs from
  `size(B, 1)`.
"""
function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)

    # Validate shapes up front instead of handing a mismatched problem to the GPU kernel.
    order = LinearAlgebra.checksquare(A.factors)
    order == M || throw(DimensionMismatch(
        "LU factors have order $order, but the right-hand side has $M rows"))

    # Nothing to solve for an empty right-hand side; skip the GPU dispatch entirely.
    isempty(B) && return B

    dev = current_device()
    queue = global_queue(dev)

    # Scratch buffers holding the transposed operands and the (transposed) solution.
    # NOTE(review): the transposes presumably bridge Julia's column-major layout and the
    # layout MPS expects — confirm against the MPSMatrix wrapper.
    At = similar(A.factors)
    Bt = similar(B, (N, M))
    # Shift the pivot indices down by one (1-based -> 0-based) and lay them out as 1 x M.
    P = reshape(A.ipiv .- UInt32(1), (1, M))
    X = similar(B, (N, M))

    transpose!(At, A.factors)
    transpose!(Bt, B)

    mps_a = MPSMatrix(At)
    mps_b = MPSMatrix(Bt)
    mps_p = MPSMatrix(P)
    mps_x = MPSMatrix(X)

    MTLCommandBuffer(queue) do cmdbuf
        kernel = MPSMatrixSolveLU(dev, false, M, N)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
    end

    # Transpose the solution back into the caller's array.
    transpose!(B, X)
    return B
end
222234
223- function LinearAlgebra. ldiv!(A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
224- M,N = size(B)
235+
"""
    ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve `A * X = B` in place with `MPSMatrixSolveTriangular`, overwriting `B` with the
solution.

# Returns
`B`, now holding the solution.

# Throws
- `DimensionMismatch` if the order of `A` differs from `size(B, 1)`.
"""
function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)

    # Validate shapes up front instead of handing a mismatched problem to the GPU kernel.
    size(A, 1) == M || throw(DimensionMismatch(
        "triangular matrix has order $(size(A, 1)), but the right-hand side has $M rows"))

    # Nothing to solve for an empty right-hand side; skip the GPU dispatch entirely.
    isempty(B) && return B

    dev = current_device()
    queue = global_queue(dev)

    # Same scheme as the other triangular `ldiv!` methods: materialize `A`, view `B` as an
    # M x N matrix (so vectors work too) and let the kernel write into a scratch buffer of
    # that same shape. No explicit transposes and no M x M temporary are needed, which also
    # makes non-square right-hand sides work.
    Ad = MtlMatrix(A)
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # Flags mirror the UnitUpperTriangular method, except that the diagonal is not
        # assumed to be unit (last boolean flag is `false`).
        kernel = MPSMatrixSolveTriangular(dev, true, false, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    wait_completed(buf)

    # `Br` is a reshape of `B`, so this writes the solution into the caller's array.
    copy!(Br, X)
    return B
end
243261
244- function LinearAlgebra. ldiv!(A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
245- M,N = size(B)
262+
"""
    ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve `A * X = B` in place with `MPSMatrixSolveTriangular`, overwriting `B` with the
solution.

# Returns
`B`, now holding the solution.

# Throws
- `DimensionMismatch` if the order of `A` differs from `size(B, 1)`.
"""
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)

    # Validate shapes up front instead of handing a mismatched problem to the GPU kernel.
    size(A, 1) == M || throw(DimensionMismatch(
        "triangular matrix has order $(size(A, 1)), but the right-hand side has $M rows"))

    # Nothing to solve for an empty right-hand side; skip the GPU dispatch entirely.
    isempty(B) && return B

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper and view `B` as an M x N matrix so that vector
    # right-hand sides are handled by the same code path.
    Ad = MtlMatrix(A)
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # NOTE(review): the boolean flags are presumably (right, upper, transpose, unit),
        # following the MPS initializer — confirm against the wrapper's signature.
        kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    wait_completed(buf)

    # `Br` is a reshape of `B`, so this writes the solution into the caller's array.
    copy!(Br, X)
    return B
end
263287
264- function LinearAlgebra. ldiv!(A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
265- M,N = size(B)
288+
"""
    ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve `A * X = B` in place with `MPSMatrixSolveTriangular`, overwriting `B` with the
solution.

# Returns
`B`, now holding the solution.

# Throws
- `DimensionMismatch` if the order of `A` differs from `size(B, 1)`.
"""
function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)

    # Validate shapes up front instead of handing a mismatched problem to the GPU kernel.
    size(A, 1) == M || throw(DimensionMismatch(
        "triangular matrix has order $(size(A, 1)), but the right-hand side has $M rows"))

    # Nothing to solve for an empty right-hand side; skip the GPU dispatch entirely.
    isempty(B) && return B

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper and view `B` as an M x N matrix so that vector
    # right-hand sides are handled by the same code path.
    Ad = MtlMatrix(A)
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # NOTE(review): the boolean flags are presumably (right, upper, transpose, unit),
        # following the MPS initializer — confirm against the wrapper's signature.
        kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    wait_completed(buf)

    # `Br` is a reshape of `B`, so this writes the solution into the caller's array.
    copy!(Br, X)
    return B
end
283313
284- # function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
285- # require_one_based_indexing(A, B)
286- # m, n = size(A)
287- # if m == n
288- # if istril(A)
289- # if istriu(A)
290- # return Diagonal(A) \ B
291- # else
292- # return LowerTriangular(A) \ B
293- # end
294- # end
295- # if istriu(A)
296- # return UpperTriangular(A) \ B
297- # end
298- # return lu(A) \ B
299- # end
300- # return qr(A, ColumnNorm()) \ B
301- # end
314+
"""
    ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T})

Solve `A * X = B` in place with `MPSMatrixSolveTriangular`, overwriting `B` with the
solution.

# Returns
`B`, now holding the solution.

# Throws
- `DimensionMismatch` if the order of `A` differs from `size(B, 1)`.
"""
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)

    # Validate shapes up front instead of handing a mismatched problem to the GPU kernel.
    size(A, 1) == M || throw(DimensionMismatch(
        "triangular matrix has order $(size(A, 1)), but the right-hand side has $M rows"))

    # Nothing to solve for an empty right-hand side; skip the GPU dispatch entirely.
    isempty(B) && return B

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper and view `B` as an M x N matrix so that vector
    # right-hand sides are handled by the same code path.
    Ad = MtlMatrix(A)
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # NOTE(review): the boolean flags are presumably (right, upper, transpose, unit),
        # following the MPS initializer — confirm against the wrapper's signature.
        kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    wait_completed(buf)

    # `Br` is a reshape of `B`, so this writes the solution into the caller's array.
    copy!(Br, X)
    return B
end
0 commit comments