3030 eltype (A) <: AbstractFloat && rows <= 6000 && cols <= 6000 && Metal. supports_family (device (C), MTL. MTLGPUFamilyApple9)
3131end
3232
33+ # Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays
34+ const matmul_alg = ScopedValue (:auto )
35+
3336LinearAlgebra. generic_matmatmul! (C:: MtlMatrix , tA, tB, A:: MtlMatrix , B:: MtlMatrix , _add:: MulAddMul ) =
3437 LinearAlgebra. generic_matmatmul! (C, tA, tB, A, B, _add. alpha, _add. beta)
3538@autoreleasepool function LinearAlgebra. generic_matmatmul! (C:: MtlMatrix , tA, tB,
@@ -55,13 +58,16 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
5558 transA = tA == ' T' || tA == ' C'
5659 transB = tB == ' T' || tB == ' C'
5760
61+ alg = matmul_alg[]
5862 # If possible, dispatch to MPSGraphs, then performance shaders
59- if supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && ! should_use_MPS (A, B, C)
63+ if supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && (alg === :MPSGraph || (alg === :auto && ! should_use_MPS (A, B, C)) )
6064 graph_matmul! (C, A, B, alpha, beta, transA, transB)
61- elseif supports_mps_matmul (A, B, C, MPS_VALID_MATMUL_TYPES) # TODO : Remove once contiguous views are working
65+ elseif supports_mps_matmul (A, B, C, MPS_VALID_MATMUL_TYPES) && (alg === :MPS || alg === :auto )
6266 matmul! (C, A, B, alpha, beta, transA, transB)
63- else
67+ elseif alg === :GPUArrays || alg === :auto
6468 GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
69+ else
70+ error (" Invalid matmul algorithm and input combination." )
6571 end
6672end
6773
@@ -90,13 +96,16 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
9096
9197 transA = tA == ' T' || tA == ' C'
9298
99+ alg = matmul_alg[]
93100 # If possible, dispatch to MPSGraphs, then performance shaders
94- if supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
101+ if supports_mpsgraph_matmul (A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) && (alg === :MPSGraph || alg === :auto )
95102 graph_matvecmul! (C, A, B, alpha, beta, transA)
96- elseif supports_mps_matmul (A, B, C, MPS_VALID_MATVECMUL_TYPES) # TODO : Remove once contiguous views are working
103+ elseif supports_mps_matmul (A, B, C, MPS_VALID_MATVECMUL_TYPES) && (alg === :MPS || alg === :auto )
97104 matvecmul! (C, A, B, alpha, beta, transA)
98- else
105+ elseif alg === :GPUArrays || alg === :auto
99106 GPUArrays. generic_matmatmul! (C, wrap (A, tA), B, alpha, beta)
107+ else
108+ error (" Invalid matmul algorithm and input combination." )
100109 end
101110end
102111
0 commit comments