|
32 | 32 |
|
33 | 33 | # Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays |
34 | 34 | const matmul_alg = ScopedValue(:auto) |
| 35 | +matmul_alg_error(alg, inT, outT) = error("Matrix multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.") |
35 | 36 |
|
36 | 37 | LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) = |
37 | 38 | LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) |
@@ -59,15 +60,19 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri |
59 | 60 | transB = tB == 'T' || tB == 'C' |
60 | 61 |
|
61 | 62 | alg = matmul_alg[] |
| 63 | + mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) |
| 64 | + mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) |
62 | 65 | # Honor an explicitly requested algorithm (error if unsupported); under :auto try MPSGraph, then MPS, then fall back to GPUArrays |
63 | | - if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && (alg === :MPSGraph || (alg === :auto && !should_use_MPS(A, B, C))) |
| 66 | + if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C)) |
| 67 | + mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
64 | 68 | graph_matmul!(C, A, B, alpha, beta, transA, transB) |
65 | | - elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) && (alg === :MPS || alg === :auto) |
| 69 | + elseif alg === :MPS || (alg === :auto && mps_supported) |
| 70 | + mps_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
66 | 71 | matmul!(C, A, B, alpha, beta, transA, transB) |
67 | 72 | elseif alg === :GPUArrays || alg === :auto |
68 | 73 | GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) |
69 | 74 | else |
70 | | - error("Invalid matmul algorithm and input combination.") |
| 75 | + error(":$alg is not a valid matmul algorithm. Options are: `:auto`, `:MPS`, `:MPSGraph`, `:GPUArrays`") |
71 | 76 | end |
72 | 77 | end |
73 | 78 |
|
@@ -97,15 +102,19 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B |
97 | 102 | transA = tA == 'T' || tA == 'C' |
98 | 103 |
|
99 | 104 | alg = matmul_alg[] |
| 105 | + mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) |
| 106 | + mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) |
100 | 107 | # Honor an explicitly requested algorithm (error if unsupported); under :auto try MPSGraph, then MPS, then fall back to GPUArrays |
101 | | - if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) && (alg === :MPSGraph || alg === :auto) |
| 108 | + if alg === :MPSGraph || (alg === :auto && mpsgraph_supported) |
| 109 | + mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
102 | 110 | graph_matvecmul!(C, A, B, alpha, beta, transA) |
103 | | - elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) && (alg === :MPS || alg === :auto) |
| 111 | + elseif alg === :MPS || (alg === :auto && mps_supported) |
| 112 | + mps_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
104 | 113 | matvecmul!(C, A, B, alpha, beta, transA) |
105 | 114 | elseif alg === :GPUArrays || alg === :auto |
106 | 115 | GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta) |
107 | 116 | else |
108 | | - error("Invalid matmul algorithm and input combination.") |
| 117 | + error(":$alg is not a valid matmul algorithm. Options are: `:auto`, `:MPS`, `:MPSGraph`, `:GPUArrays`") |
109 | 118 | end |
110 | 119 | end |
111 | 120 |
|
|
0 commit comments