@@ -2,6 +2,24 @@ using LinearAlgebra
22using LinearAlgebra: MulAddMul, wrap
33using . MPS
44using . MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
5+ using . MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
6+ graph_matmul!, graph_matvecmul!
7+
8+ @inline function supports_mps_matmul (A, B, C, valid_types)
9+ MPS. is_supported (device (A)) &&
10+ eltype (A) == eltype (B) &&
11+ (eltype (A), eltype (C)) in valid_types
12+ end
13+
14+ @inline function supports_mpsgraph_matmul (A, B, C, valid_types)
15+ MPS. is_supported (device (A)) &&
16+ eltype (A) == eltype (B) &&
17+ (eltype (A), eltype (C)) in valid_types &&
18+ # TODO : remove this limitation
19+ A. offset == 0 &&
20+ B. offset == 0 &&
21+ C. offset == 0
22+ end
523
624LinearAlgebra. generic_matmatmul! (C:: MtlMatrix , tA, tB, A:: MtlMatrix , B:: MtlMatrix , _add:: MulAddMul ) =
725 LinearAlgebra. generic_matmatmul! (C, tA, tB, A, B, _add. alpha, _add. beta)
@@ -28,13 +46,10 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
2846 transA = tA == ' T' || tA == ' C'
2947 transB = tB == ' T' || tB == ' C'
3048
31- typA = eltype (A)
32- typB = eltype (B)
33- typC = eltype (C)
34-
35- # If possible, dispatch to performance shaders
36- if MPS. is_supported (device ()) &&
37- typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
49+ # If possible, dispatch to MPSGraphs, then performance shaders
50+ if supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
51+ graph_matmul! (C, A, B, alpha, beta, transA, transB)
52+ elseif supports_mps_matmul (A, B, C, MPS_VALID_MATMUL_TYPES)
3853 matmul! (C, A, B, alpha, beta, transA, transB)
3954 else
4055 GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
@@ -66,13 +81,10 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
6681
6782 transA = tA == ' T' || tA == ' C'
6883
69- typA = eltype (A)
70- typB = eltype (B)
71- typC = eltype (C)
72-
73- # If possible, dispatch to performance shaders
74- if MPS. is_supported (device ()) &&
75- typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
84+ # If possible, dispatch to MPSGraphs, then performance shaders
85+ if supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
86+ graph_matvecmul! (C, A, B, alpha, beta, transA)
87+ elseif supports_mps_matmul (A, B, C, MPS_VALID_MATVECMUL_TYPES)
7688 matvecmul! (C, A, B, alpha, beta, transA)
7789 else
7890 GPUArrays. generic_matmatmul! (C, wrap (A, tA), B, alpha, beta)
0 commit comments