|
32 | 32 |
|
33 | 33 | # Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays |
34 | 34 | const matmul_alg = ScopedValue(:auto) |
35 | | -matmul_alg_error(alg, inT, outT) = error("Matrix multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.") |
| 35 | +matmul_alg_error(alg, inT, outT, vec) = error("Matrix-$(vec ? "Vector" : "Matrix") multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.") |
36 | 36 |
|
37 | 37 | LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) = |
38 | 38 | LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) |
@@ -64,10 +64,10 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri |
64 | 64 | mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) |
65 | 65 | # If possible, dispatch to MPSGraphs, then performance shaders |
66 | 66 | if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C)) |
67 | | - mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
| 67 | + mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), false) |
68 | 68 | graph_matmul!(C, A, B, alpha, beta, transA, transB) |
69 | 69 | elseif alg === :MPS || (alg === :auto && mps_supported) |
70 | | - mps_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
| 70 | + mps_supported || matmul_alg_error(alg, eltype(A), eltype(C), false) |
71 | 71 | matmul!(C, A, B, alpha, beta, transA, transB) |
72 | 72 | elseif alg === :GPUArrays || alg === :auto |
73 | 73 | GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) |
@@ -106,10 +106,10 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B |
106 | 106 | mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) |
107 | 107 | # If possible, dispatch to MPSGraphs, then performance shaders |
108 | 108 | if alg === :MPSGraph || (alg === :auto && mpsgraph_supported) |
109 | | - mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
| 109 | + mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), true) |
110 | 110 | graph_matvecmul!(C, A, B, alpha, beta, transA) |
111 | 111 | elseif alg === :MPS || (alg === :auto && mps_supported) |
112 | | - mps_supported || matmul_alg_error(alg, eltype(A), eltype(C)) |
| 112 | + mps_supported || matmul_alg_error(alg, eltype(A), eltype(C), true) |
113 | 113 | matvecmul!(C, A, B, alpha, beta, transA) |
114 | 114 | elseif alg === :GPUArrays || alg === :auto |
115 | 115 | GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta) |
|
0 commit comments