@@ -14,20 +14,22 @@ if MPS.is_supported(device())
1414 alpha = Float64 (1 )
1515 beta = Float64 (1 )
1616
17- @testset " $(input_jl_type) => $accum_jl_type " for (input_jl_type, accum_jl_type) in MPSGraphs. MPSGRAPH_VALID_MATMUL_TYPES
18- arr_a = rand (input_jl_type, (rows_a, cols_a))
19- arr_b = rand (input_jl_type, (rows_b, cols_b))
20- arr_c = zeros (accum_jl_type, (rows_c, cols_c))
17+ for (input_jl_type, accum_jl_type) in MPSGraphs. MPSGRAPH_VALID_MATMUL_TYPES
18+ @testset let input_jl_type = input_jl_type, accum_jl_type = accum_jl_type
19+ arr_a = rand (input_jl_type, (rows_a, cols_a))
20+ arr_b = rand (input_jl_type, (rows_b, cols_b))
21+ arr_c = zeros (accum_jl_type, (rows_c, cols_c))
2122
22- buf_a = MtlArray {input_jl_type} (arr_a)
23- buf_b = MtlArray {input_jl_type} (arr_b)
24- buf_c = MtlArray {accum_jl_type} (undef, (rows_c, cols_c ))
23+ buf_a = MtlArray {input_jl_type} (arr_a)
24+ buf_b = MtlArray {input_jl_type} (arr_b)
25+ buf_c = Metal . zeros (accum_jl_type, size (arr_c ))
2526
26- truth_c = (alpha .* accum_jl_type .(arr_a)) * accum_jl_type .(arr_b) .+ (beta .* arr_c)
27+ truth_c = (alpha .* accum_jl_type .(arr_a)) * accum_jl_type .(arr_b) .+ (beta .* arr_c)
2728
28- MPSGraphs. graph_matmul! (buf_c, buf_a, buf_b, alpha, beta)
29+ MPSGraphs. graph_matmul! (buf_c, buf_a, buf_b, alpha, beta)
2930
30- @test all (Array (buf_c) .≈ truth_c)
31+ @test Array (buf_c) ≈ truth_c
32+ end
3133 end
3234end
3335
5658
5759 buf_a = MtlArray {input_jl_type} (arr_a)
5860 buf_b = MtlArray {input_jl_type} (arr_b)
59- buf_c = MtlArray {accum_jl_type} (undef , (rows_c, cols_c, batch_size))
61+ buf_c = Metal . zeros (accum_jl_type , (rows_c, cols_c, batch_size))
6062
61- truth_c = Array {accum_jl_type} (undef , (rows_c, cols_c, batch_size))
63+ truth_c = zeros (accum_jl_type , (rows_c, cols_c, batch_size))
6264 for i in 1 : batch_size
6365 @views truth_c[:, :, i] = (alpha .* accum_jl_type .(arr_a[:, :, i])) * accum_jl_type .(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
6466 end
6567
6668 MPSGraphs. graph_matmul! (buf_c, buf_a, buf_b, alpha, beta)
6769
68- @test all ( Array (buf_c) . ≈ truth_c)
70+ @test Array (buf_c) ≈ truth_c
6971 end
7072end
7173
7981
8082 @testset " $(input_jl_type) => $accum_jl_type " for (input_jl_type, accum_jl_type) in MPSGraphs. MPSGRAPH_VALID_MATVECMUL_TYPES
8183 arr_a = rand (input_jl_type, (rows, cols))
82- arr_b = rand (input_jl_type, ( rows) )
83- arr_c = zeros (accum_jl_type, ( rows) )
84+ arr_b = rand (input_jl_type, rows)
85+ arr_c = zeros (accum_jl_type, rows)
8486
8587 buf_a = MtlArray {input_jl_type} (arr_a)
8688 buf_b = MtlArray {input_jl_type} (arr_b)
87- buf_c = MtlArray {accum_jl_type} (undef, ( rows) )
89+ buf_c = Metal . zeros (accum_jl_type, rows)
8890
8991 truth_c = (accum_jl_type (alpha) .* accum_jl_type .(arr_a)) * accum_jl_type .(arr_b) .+ (accum_jl_type (beta) .* arr_c)
9092
9193 MPSGraphs. graph_matvecmul! (buf_c, buf_a, buf_b, alpha, beta)
9294
93- @test all ( Array (buf_c) . ≈ truth_c)
95+ @test Array (buf_c) ≈ truth_c
9496 end
9597end
9698
0 commit comments