Skip to content

Commit f1f2c17

Browse files
committed
More tests
1 parent 7a7c003 commit f1f2c17

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

test/linalg.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,49 @@ if MPS.is_supported(device())
66
# test that unsupported configurations error properly
77
N = 20
88
function test_matmul(inT, outT; vec_b=false, alg=:auto)
9-
a = MtlArray(rand(inT, N, N))
10-
b = MtlArray(rand(inT, vec_b ? (N,) : (N, N)))
11-
c = fill!(similar(b, outT), zero(outT))
9+
a = inT <: Integer ? inT.(rand(-5:5, N,N)) : rand(inT, N, N)
1210

13-
@with (Metal.matmul_alg => alg) mul!(c,a,b)
11+
bdims = vec_b ? (N,) : (N, N)
12+
b = inT <: Integer ? inT.(rand(-5:5, bdims)) : rand(inT, bdims)
13+
14+
ma = MtlArray(a)
15+
mb = MtlArray(b)
16+
mc = fill!(similar(mb, outT), zero(outT))
17+
18+
@with (Metal.matmul_alg => alg) mul!(mc,ma,mb)
19+
20+
return all((outT.(a)*outT.(b)) .≈ Array(mc))
1421
end
1522

16-
# Unsupported for MPS and MPSGraph
1723
for vec_b in (true, false)
18-
@test_throws "Matrix multiplication algorithm `:MPS`" test_matmul(Int8, Int16; vec_b, alg=:MPS)
19-
@test_throws "Matrix multiplication algorithm `:MPSGraph`" test_matmul(Int8, Int16; vec_b, alg=:MPSGraph)
24+
@testset let vec_b = vec_b
25+
# Unsupported for MPS and MPSGraph
26+
@test_throws "Matrix-$(vec_b ? "Vector" : "Matrix") multiplication algorithm `:MPS`" test_matmul(Int8, Int16; vec_b, alg=:MPS)
27+
@test_throws "Matrix-$(vec_b ? "Vector" : "Matrix") multiplication algorithm `:MPSGraph`" test_matmul(Int8, Int16; vec_b, alg=:MPSGraph)
2028

2129
# Invalid algorithm Symbol
2230
@test_throws ":bad is not a valid matmul algorithm." test_matmul(Int8, Int16; vec_b, alg=:bad)
2331
@test_throws ":bad is not a valid matmul algorithm." test_matmul(Float16, Float16; vec_b, alg=:bad)
32+
33+
# :auto
34+
@test test_matmul(Int32, Int32; vec_b) # fallback to GPUArrays
35+
@test test_matmul(Int8, Float32; vec_b) # should use MPS
36+
@test test_matmul(Float16, Float32; vec_b) # should use MPSGraph on M1/M2
37+
38+
# :MPS
39+
mpsInT = vec_b ? Float32 : Int16
40+
@test test_matmul(mpsInT, Float32; vec_b, alg=:MPS)
41+
@test test_matmul(Float16, Float32; vec_b, alg=:MPS)
42+
43+
# :MPSGraph
44+
@test test_matmul(Int8, Float32; vec_b, alg=:MPSGraph)
45+
@test test_matmul(Float16, Float32; vec_b, alg=:MPSGraph)
46+
47+
# :GPUArrays
48+
@test test_matmul(Int32, Int32; vec_b, alg=:GPUArrays)
49+
@test test_matmul(Int8, Float32; vec_b, alg=:GPUArrays)
50+
@test test_matmul(Float16, Float32; vec_b, alg=:GPUArrays)
51+
end
2452
end
2553
end
2654

0 commit comments

Comments
 (0)