@@ -2,6 +2,24 @@ using LinearAlgebra
22using  LinearAlgebra:  MulAddMul, wrap
33using  . MPS
44using  . MPS:  MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
5+ using  . MPSGraphs:  MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
6+                   graph_matmul!, graph_matvecmul!
7+ 
# Predicate used to decide whether a multiplication can take the MPS path.
#
# Returns `true` only when all of the following hold, checked in this order
# (later checks are skipped as soon as one fails):
#   1. `MPS.is_supported` accepts the device that `A` lives on,
#   2. `A` and `B` share the same element type,
#   3. the `(eltype(A), eltype(C))` pair is listed in `valid_types`.
#
# `valid_types` is any collection of `(input_eltype, output_eltype)` tuples
# that supports `in`.
@inline function supports_mps_matmul(A, B, C, valid_types)
    MPS.is_supported(device(A)) || return false
    eltype(A) == eltype(B) || return false
    return (eltype(A), eltype(C)) in valid_types
end
13+ 
# Predicate used to decide whether a multiplication can take the MPSGraph path.
#
# The device / element-type requirements are the same ones that
# `supports_mps_matmul` checks (device accepted by `MPS.is_supported`,
# `eltype(A) == eltype(B)`, and `(eltype(A), eltype(C))` present in
# `valid_types`), so that predicate is reused here. On top of that, all three
# arrays must have a zero `offset` field.
@inline function supports_mpsgraph_matmul(A, B, C, valid_types)
    supports_mps_matmul(A, B, C, valid_types) || return false
    # TODO: remove this limitation
    return A.offset == 0 && B.offset == 0 && C.offset == 0
end
523
# Adapter for callers that pass the scaling factors bundled in a `MulAddMul`:
# unpack `alpha` and `beta` and forward to the method that takes them as
# separate trailing arguments.
function LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix,
                                          _add::MulAddMul)
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
end
@@ -28,13 +46,10 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
2846    transA =  tA ==  ' T'   ||  tA ==  ' C' 
2947    transB =  tB ==  ' T'   ||  tB ==  ' C' 
3048
31-     typA =  eltype (A)
32-     typB =  eltype (B)
33-     typC =  eltype (C)
34- 
35-     #  If possible, dispatch to performance shaders
36-     if  MPS. is_supported (device ()) && 
37-             typA ==  typB &&  (typA, typC) in  MPS_VALID_MATMUL_TYPES
49+     #  If possible, dispatch to MPSGraphs, then performance shaders
50+     if  supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
51+         graph_matmul! (C, A, B, alpha, beta, transA, transB)
52+     elseif  supports_mps_matmul (A, B, C, MPS_VALID_MATMUL_TYPES)
3853        matmul! (C, A, B, alpha, beta, transA, transB)
3954    else 
4055        GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
@@ -66,13 +81,10 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
6681
6782    transA =  tA ==  ' T'   ||  tA ==  ' C' 
6883
69-     typA =  eltype (A)
70-     typB =  eltype (B)
71-     typC =  eltype (C)
72- 
73-     #  If possible, dispatch to performance shaders
74-     if  MPS. is_supported (device ()) && 
75-             typA ==  typB &&  (typA, typC) in  MPS_VALID_MATVECMUL_TYPES
84+     #  If possible, dispatch to MPSGraphs, then performance shaders
85+     if  supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
86+         graph_matvecmul! (C, A, B, alpha, beta, transA)
87+     elseif  supports_mps_matmul (A, B, C, MPS_VALID_MATVECMUL_TYPES)
7688        matvecmul! (C, A, B, alpha, beta, transA)
7789    else 
7890        GPUArrays. generic_matmatmul! (C, wrap (A, tA), B, alpha, beta)
0 commit comments