@@ -261,5 +261,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
261261
262262    commit! (cmdbuf)
263263
264+     wait_completed (cmdbuf)
265+ 
266+     return  B
267+ end 
268+ 
269+ 
270+ function  LinearAlgebra.:(\ )(A:: LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}} , B:: MtlVecOrMat{T} ) where  {T<: MtlFloat }
271+     C =  deepcopy (B)
272+     LinearAlgebra. ldiv! (A, C)
273+     return  C
274+ end 
275+ 
276+ 
277+ function  LinearAlgebra. ldiv! (A:: LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}} , B:: MtlVecOrMat{T} ) where  {T<: MtlFloat }
278+     M, N =  size (B, 1 ), size (B, 2 )
279+     dev =  current_device ()
280+     queue =  global_queue (dev)
281+ 
282+     At =  similar (A. factors)
283+     Bt =  similar (B, (N, M))
284+     P =  reshape ((A. ipiv .-  UInt32 (1 )), (1 , M))
285+     X =  similar (B, (N, M))
286+ 
287+     transpose! (At, A. factors)
288+     transpose! (Bt, B)
289+ 
290+     mps_a =  MPSMatrix (At)
291+     mps_b =  MPSMatrix (Bt)
292+     mps_p =  MPSMatrix (P)
293+     mps_x =  MPSMatrix (X)
294+ 
295+     MTLCommandBuffer (queue) do  cmdbuf
296+         kernel =  MPSMatrixSolveLU (dev, false , M, N)
297+         encode! (cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
298+     end 
299+ 
300+     transpose! (B, X)
301+     return  B
302+ end 
303+ 
304+ 
305+ function  LinearAlgebra. ldiv! (A:: UpperTriangular{T,<:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where  {T<: MtlFloat }
306+     M, N =  size (B, 1 ), size (B, 2 )
307+     dev =  current_device ()
308+     queue =  global_queue (dev)
309+ 
310+     Ad =  MtlMatrix (A' )
311+     Br =  similar (B, (M, M))
312+     X =  similar (Br)
313+ 
314+     transpose! (Br, B)
315+ 
316+     mps_a =  MPSMatrix (Ad)
317+     mps_b =  MPSMatrix (Br)
318+     mps_x =  MPSMatrix (X)
319+ 
320+     buf =  MTLCommandBuffer (queue) do  cmdbuf
321+         kernel =  MPSMatrixSolveTriangular (dev, false , true , false , false , N, M, 1.0 )
322+         encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
323+     end 
324+ 
325+     wait_completed (buf)
326+ 
327+     copy! (B, X)
328+     return  B
329+ end 
330+ 
331+ 
332+ function  LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T,<:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where  {T<: MtlFloat }
333+     M, N =  size (B, 1 ), size (B, 2 )
334+     dev =  current_device ()
335+     queue =  global_queue (dev)
336+ 
337+     Ad =  MtlMatrix (A)
338+     Br =  reshape (B, (M, N))
339+     X =  similar (Br)
340+ 
341+     mps_a =  MPSMatrix (Ad)
342+     mps_b =  MPSMatrix (Br)
343+     mps_x =  MPSMatrix (X)
344+ 
345+ 
346+     buf =  MTLCommandBuffer (queue) do  cmdbuf
347+         kernel =  MPSMatrixSolveTriangular (dev, true , false , false , true , M, N, 1.0 )
348+         encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
349+     end 
350+ 
351+     wait_completed (buf)
352+ 
353+     copy! (Br, X)
354+     return  B
355+ end 
356+ 
357+ 
358+ function  LinearAlgebra. ldiv! (A:: LowerTriangular{T,<:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where  {T<: MtlFloat }
359+     M, N =  size (B, 1 ), size (B, 2 )
360+     dev =  current_device ()
361+     queue =  global_queue (dev)
362+ 
363+     Ad =  MtlMatrix (A)
364+     Br =  reshape (B, (M, N))
365+     X =  similar (Br)
366+ 
367+     mps_a =  MPSMatrix (Ad)
368+     mps_b =  MPSMatrix (Br)
369+     mps_x =  MPSMatrix (X)
370+ 
371+ 
372+     buf =  MTLCommandBuffer (queue) do  cmdbuf
373+         kernel =  MPSMatrixSolveTriangular (dev, true , true , false , false , M, N, 1.0 )
374+         encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
375+     end 
376+ 
377+     wait_completed (buf)
378+ 
379+     copy! (Br, X)
264380    return  B
265381end 
382+ 
383+ 
384+ function  LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T,<:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where  {T<: MtlFloat }
385+     M, N =  size (B, 1 ), size (B, 2 )
386+     dev =  current_device ()
387+     queue =  global_queue (dev)
388+ 
389+     Ad =  MtlMatrix (A)
390+     Br =  reshape (B, (M, N))
391+     X =  similar (Br)
392+ 
393+     mps_a =  MPSMatrix (Ad)
394+     mps_b =  MPSMatrix (Br)
395+     mps_x =  MPSMatrix (X)
396+ 
397+ 
398+     buf =  MTLCommandBuffer (queue) do  cmdbuf
399+         kernel =  MPSMatrixSolveTriangular (dev, true , true , false , true , M, N, 1.0 )
400+         encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
401+     end 
402+ 
403+     wait_completed (buf)
404+ 
405+     copy! (Br, X)
406+     return  B
407+ end 
0 commit comments