@@ -6,21 +6,49 @@ if MPS.is_supported(device())
66 # test that unsupported configurations error properly
77 N = 20
88 function test_matmul (inT, outT; vec_b= false , alg= :auto )
9- a = MtlArray (rand (inT, N, N))
10- b = MtlArray (rand (inT, vec_b ? (N,) : (N, N)))
11- c = fill! (similar (b, outT), zero (outT))
9+ a = inT <: Integer ? inT .(rand (- 5 : 5 , N,N)) : rand (inT, N, N)
1210
13- @with (Metal. matmul_alg => alg) mul! (c,a,b)
11+ bdims = vec_b ? (N,) : (N, N)
12+ b = inT <: Integer ? inT .(rand (- 5 : 5 , bdims)) : rand (inT, bdims)
13+
14+ ma = MtlArray (a)
15+ mb = MtlArray (b)
16+ mc = fill! (similar (mb, outT), zero (outT))
17+
18+ @with (Metal. matmul_alg => alg) mul! (mc,ma,mb)
19+
20+ return all ((outT .(a)* outT .(b)) .≈ Array (mc))
1421 end
1522
16- # Unsupported for MPS and MPSGraph
1723 for vec_b in (true , false )
18- @test_throws " Matrix multiplication algorithm `:MPS`" test_matmul (Int8, Int16; vec_b, alg= :MPS )
19- @test_throws " Matrix multiplication algorithm `:MPSGraph`" test_matmul (Int8, Int16; vec_b, alg= :MPSGraph )
24+ @testset let vec_b = vec_b
25+ # Unsupported for MPS and MPSGraph
26+ @test_throws " Matrix-$(vec_b ? " Vector" : " Matrix" ) multiplication algorithm `:MPS`" test_matmul (Int8, Int16; vec_b, alg= :MPS )
27+ @test_throws " Matrix-$(vec_b ? " Vector" : " Matrix" ) multiplication algorithm `:MPSGraph`" test_matmul (Int8, Int16; vec_b, alg= :MPSGraph )
2028
2129 # Invalid algorithm Symbol
2230 @test_throws " :bad is not a valid matmul algorithm." test_matmul (Int8, Int16; vec_b, alg= :bad )
2331 @test_throws " :bad is not a valid matmul algorithm." test_matmul (Float16, Float16; vec_b, alg= :bad )
32+
33+ # :auto
34+ @test test_matmul (Int32, Int32; vec_b) # fallback to GPUArrays
35+ @test test_matmul (Int8, Float32; vec_b) # should use MPS
36+ @test test_matmul (Float16, Float32; vec_b) # should use MPSGraph on M1/M2
37+
38+ # :MPS
39+ mpsInT = vec_b ? Float32 : Int16
40+ @test test_matmul (mpsInT, Float32; vec_b, alg= :MPS )
41+ @test test_matmul (Float16, Float32; vec_b, alg= :MPS )
42+
43+ # :MPSGraph
44+ @test test_matmul (Int8, Float32; vec_b, alg= :MPSGraph )
45+ @test test_matmul (Float16, Float32; vec_b, alg= :MPSGraph )
46+
47+ # :GPUArrays
48+ @test test_matmul (Int32, Int32; vec_b, alg= :GPUArrays )
49+ @test test_matmul (Int8, Float32; vec_b, alg= :GPUArrays )
50+ @test test_matmul (Float16, Float32; vec_b, alg= :GPUArrays )
51+ end
2452 end
2553end
2654
0 commit comments