Skip to content

Commit 1e1e193

Browse files
Fix linalg tests for MPS and MPSGraph (#618)
1 parent b8fb812 commit 1e1e193

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

test/mps/linalg.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ if MPS.is_supported(device())
1616

1717
for (input_jl_type, accum_jl_type) in MPS.MPS_VALID_MATMUL_TYPES
1818
@testset let input_jl_type = input_jl_type, accum_jl_type = accum_jl_type
19-
arr_a = rand(input_jl_type, (rows_a,cols_a))
20-
arr_b = rand(input_jl_type, (rows_b,cols_b))
21-
arr_c = zeros(accum_jl_type, (rows_c,cols_c))
19+
arr_a = rand(input_jl_type, (rows_a, cols_a))
20+
arr_b = rand(input_jl_type, (rows_b, cols_b))
21+
arr_c = zeros(accum_jl_type, (rows_c, cols_c))
2222

2323
buf_a = MtlArray{input_jl_type}(arr_a)
2424
buf_b = MtlArray{input_jl_type}(arr_b)
25-
buf_c = MtlArray{accum_jl_type}(undef, (rows_c,cols_c))
25+
buf_c = Metal.zeros(accum_jl_type, size(arr_c))
2626

27-
truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
27+
truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
2828

2929
MPS.matmul!(buf_c, buf_a, buf_b, alpha, beta)
3030

@@ -59,16 +59,16 @@ end
5959

6060
buf_a = MtlArray{input_jl_type}(arr_a)
6161
buf_b = MtlArray{input_jl_type}(arr_b)
62-
buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
62+
buf_c = Metal.zeros(accum_jl_type, (rows_c, cols_c, batch_size))
6363

64-
truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
64+
truth_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
6565
for i in 1:batch_size
6666
@views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
6767
end
6868

6969
MPS.matmul!(buf_c, buf_a, buf_b, alpha, beta)
7070

71-
@test all(Array(buf_c) . truth_c)
71+
@test Array(buf_c) truth_c
7272
end
7373
end
7474
end
@@ -83,18 +83,18 @@ end
8383

8484
@testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPS.MPS_VALID_MATVECMUL_TYPES
8585
arr_a = rand(input_jl_type, (rows,cols))
86-
arr_b = rand(input_jl_type, (rows))
87-
arr_c = zeros(accum_jl_type, (rows))
86+
arr_b = rand(input_jl_type, (rows,))
87+
arr_c = zeros(accum_jl_type, (rows,))
8888

8989
buf_a = MtlArray{input_jl_type}(arr_a)
9090
buf_b = MtlArray{input_jl_type}(arr_b)
91-
buf_c = MtlArray{accum_jl_type}(undef, (rows))
91+
buf_c = Metal.zeros(accum_jl_type, (rows,))
9292

9393
truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
9494

9595
MPS.matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
9696

97-
@test all(Array(buf_c) . truth_c)
97+
@test Array(buf_c) truth_c
9898
end
9999
end
100100

test/mpsgraphs/linalg.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@ if MPS.is_supported(device())
1414
alpha = Float64(1)
1515
beta = Float64(1)
1616

17-
@testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
18-
arr_a = rand(input_jl_type, (rows_a, cols_a))
19-
arr_b = rand(input_jl_type, (rows_b, cols_b))
20-
arr_c = zeros(accum_jl_type, (rows_c, cols_c))
17+
for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
18+
@testset let input_jl_type = input_jl_type, accum_jl_type = accum_jl_type
19+
arr_a = rand(input_jl_type, (rows_a, cols_a))
20+
arr_b = rand(input_jl_type, (rows_b, cols_b))
21+
arr_c = zeros(accum_jl_type, (rows_c, cols_c))
2122

22-
buf_a = MtlArray{input_jl_type}(arr_a)
23-
buf_b = MtlArray{input_jl_type}(arr_b)
24-
buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
23+
buf_a = MtlArray{input_jl_type}(arr_a)
24+
buf_b = MtlArray{input_jl_type}(arr_b)
25+
buf_c = Metal.zeros(accum_jl_type, size(arr_c))
2526

26-
truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
27+
truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
2728

28-
MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
29+
MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
2930

30-
@test all(Array(buf_c) .≈ truth_c)
31+
@test Array(buf_c) truth_c
32+
end
3133
end
3234
end
3335

@@ -56,16 +58,16 @@ end
5658

5759
buf_a = MtlArray{input_jl_type}(arr_a)
5860
buf_b = MtlArray{input_jl_type}(arr_b)
59-
buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
61+
buf_c = Metal.zeros(accum_jl_type, (rows_c, cols_c, batch_size))
6062

61-
truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
63+
truth_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
6264
for i in 1:batch_size
6365
@views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
6466
end
6567

6668
MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
6769

68-
@test all(Array(buf_c) . truth_c)
70+
@test Array(buf_c) truth_c
6971
end
7072
end
7173

@@ -79,18 +81,18 @@ end
7981

8082
@testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
8183
arr_a = rand(input_jl_type, (rows, cols))
82-
arr_b = rand(input_jl_type, (rows))
83-
arr_c = zeros(accum_jl_type, (rows))
84+
arr_b = rand(input_jl_type, rows)
85+
arr_c = zeros(accum_jl_type, rows)
8486

8587
buf_a = MtlArray{input_jl_type}(arr_a)
8688
buf_b = MtlArray{input_jl_type}(arr_b)
87-
buf_c = MtlArray{accum_jl_type}(undef, (rows))
89+
buf_c = Metal.zeros(accum_jl_type, rows)
8890

8991
truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
9092

9193
MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
9294

93-
@test all(Array(buf_c) . truth_c)
95+
@test Array(buf_c) truth_c
9496
end
9597
end
9698

0 commit comments

Comments
 (0)