Skip to content

Commit a634056

Browse files
committed
Fix & tests
1 parent c039ca9 commit a634056

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

src/linalg.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ end
3232

3333
# Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays
3434
const matmul_alg = ScopedValue(:auto)
35+
matmul_alg_error(alg, inT, outT) = error("Matrix multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.")
3536

3637
LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
3738
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
@@ -59,15 +60,19 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
5960
transB = tB == 'T' || tB == 'C'
6061

6162
alg = matmul_alg[]
63+
mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES)
64+
mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
6265
# If possible, dispatch to MPSGraphs, then performance shaders
63-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && (alg === :MPSGraph || (alg === :auto && !should_use_MPS(A, B, C)))
66+
if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C))
67+
mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C))
6468
graph_matmul!(C, A, B, alpha, beta, transA, transB)
65-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) && (alg === :MPS || alg === :auto)
69+
elseif alg === :MPS || (alg === :auto && mps_supported)
70+
mps_supported || matmul_alg_error(alg, eltype(A), eltype(C))
6671
matmul!(C, A, B, alpha, beta, transA, transB)
6772
elseif alg === :GPUArrays || alg === :auto
6873
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
6974
else
70-
error("Invalid matmul algorithm and input combination.")
75+
error(":$alg is not a valid matmul algorithm. Options are: `:auto`, `:MPS`, `:MPSGraph`, `:GPUArrays`")
7176
end
7277
end
7378

@@ -97,15 +102,19 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
97102
transA = tA == 'T' || tA == 'C'
98103

99104
alg = matmul_alg[]
105+
mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES)
106+
mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
100107
# If possible, dispatch to MPSGraphs, then performance shaders
101-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) && (alg === :MPSGraph || alg === :auto)
108+
if alg === :MPSGraph || (alg === :auto && mpsgraph_supported)
109+
mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C))
102110
graph_matvecmul!(C, A, B, alpha, beta, transA)
103-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) && (alg === :MPS || alg === :auto)
111+
elseif alg === :MPS || (alg === :auto && mps_supported)
112+
mps_supported || matmul_alg_error(alg, eltype(A), eltype(C))
104113
matvecmul!(C, A, B, alpha, beta, transA)
105114
elseif alg === :GPUArrays || alg === :auto
106115
GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)
107116
else
108-
error("Invalid matmul algorithm and input combination.")
117+
error(":$alg is not a valid matmul algorithm. Options are: `:auto`, `:MPS`, `:MPSGraph`, `:GPUArrays`")
109118
end
110119
end
111120

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
1414
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1515
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1616
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1718
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1819
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1920
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

test/linalg.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
1-
using LinearAlgebra
1+
using LinearAlgebra, ScopedValues
22

33
if MPS.is_supported(device())
44

5+
@testset "matmul algorithm selection" begin
6+
# test that unsupported configurations error properly
7+
N = 20
8+
function test_matmul(inT, outT; vec_b=false, alg=:auto)
9+
a = MtlArray(rand(inT, N, N))
10+
b = MtlArray(rand(inT, vec_b ? (N,) : (N, N)))
11+
c = fill!(similar(b, outT), zero(outT))
12+
13+
@with (Metal.matmul_alg => alg) mul!(c,a,b)
14+
end
15+
16+
# Unsupported for MPS and MPSGraph
17+
for vec_b in (true, false)
18+
@test_throws "Matrix multiplication algorithm `:MPS`" test_matmul(Int8, Int16; vec_b, alg=:MPS)
19+
@test_throws "Matrix multiplication algorithm `:MPSGraph`" test_matmul(Int8, Int16; vec_b, alg=:MPSGraph)
20+
21+
# Invalid algorithm Symbol
22+
@test_throws ":bad is not a valid matmul algorithm." test_matmul(Int8, Int16; vec_b, alg=:bad)
23+
@test_throws ":bad is not a valid matmul algorithm." test_matmul(Float16, Float16; vec_b, alg=:bad)
24+
end
25+
end
526

627
@testset "test matrix vector multiplication of views" begin
728
N = 20

0 commit comments

Comments
 (0)