@@ -6,15 +6,15 @@ using LinearAlgebra
6
6
# ###############################################################################
7
7
8
8
@testset " Matmul API" begin
9
- @test_if " fpu compute and data types " @ testset " FPU GEMM $(A_type) *$(B_type) +$(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
9
+ @testset " FPU GEMM $(A_type) *$(B_type) +$(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
10
10
(A_type, B_type, CD_type, min_dimension) in [
11
11
(Float16, Float16, Float32, 128 ), (Float32, Float32, Float32, 128 ), (Float32, Float32, Float64, 128 ), (Float64, Float64, Float64, 128 ),
12
- (Int16, Int16, Int16, 128 ), (Int32, Int32, Int32, 128 ), (Int64, Int64, Int64, 128 ),
13
- ],
14
- transpose_a = [false , true ],
15
- transpose_b = [false , true ],
12
+ (Int16, Int16, Int16, 128 ), (Int32, Int32, Int32, 128 ), (Int64, Int64, Int64, 128 ),
13
+ ],
14
+ transpose_a = [false , true ],
15
+ transpose_b = [false , true ],
16
16
(OP_M, OP_N, OP_K) in [(8 , 16 , 2 )]
17
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 , 2 , 1 ], [1 , 1 , 2 ], [2 , 2 , 2 ]], [[2048 , 2048 , 2048 ]])
17
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 , 2 , 1 ], [1 , 1 , 2 ], [2 , 2 , 2 ]], [[2048 , 2048 , 2048 ]])
18
18
alpha = convert (A_type, 2 )
19
19
beta = convert (CD_type, 3 )
20
20
@@ -59,7 +59,7 @@ using LinearAlgebra
59
59
# Transpose outputs, if necessary
60
60
new_a_h = transpose_a ? transpose (a_h) : a_h
61
61
new_b_h = transpose_b ? transpose (b_h) : b_h
62
-
62
+
63
63
if A_type <: Integer
64
64
@test all (isapprox .(alpha * CD_type .(new_a_h) * CD_type .(new_b_h) + beta * c_h, Array (d)))
65
65
else
@@ -68,13 +68,13 @@ using LinearAlgebra
68
68
end
69
69
end
70
70
71
- @test_if " fpu operator shape " @ testset " FPU GEMM OPERATOR SHAPE ($(OP_M) , $(OP_N) , $(OP_K) ) (NN, NT, TN, TT)" for (OP_M, OP_N, OP_K) in [
72
- (4 , 8 , 1 ), (8 , 8 , 1 ), (4 , 16 , 1 ), (4 , 8 , 2 ), (8 , 16 , 2 )
71
+ @testset " FPU GEMM OPERATOR SHAPE ($(OP_M) , $(OP_N) , $(OP_K) ) (NN, NT, TN, TT)" for (OP_M, OP_N, OP_K) in [
72
+ (4 , 8 , 1 ), (8 , 8 , 1 ), (4 , 16 , 1 ), (4 , 8 , 2 ), (8 , 16 , 2 )
73
73
]
74
- @testset " NN, NT, TN, TT" for (transpose_a, transpose_b) in [(false , false ), (false , true ), (true , false ), (true , true )]
74
+ @testcase " NN, NT, TN, TT" for (transpose_a, transpose_b) in [(false , false ), (false , true ), (true , false ), (true , true )]
75
75
(M, N, K) = (128 , 128 , 128 )
76
76
(A_type, B_type, CD_type) = (Float32, Float32, Float32)
77
-
77
+
78
78
alpha = convert (A_type, 2 )
79
79
beta = convert (CD_type, 3 )
80
80
@@ -114,18 +114,18 @@ using LinearAlgebra
114
114
# Transpose outputs, if necessary
115
115
new_a_h = transpose_a ? transpose (a_h) : a_h
116
116
new_b_h = transpose_b ? transpose (b_h) : b_h
117
-
117
+
118
118
@test all (isapprox .(alpha * CD_type .(new_a_h) * CD_type .(new_b_h) + beta * c_h, Array (d); rtol = sqrt (eps (A_type))))
119
119
end
120
120
end
121
121
122
- @test_if " tropical fpu " @ testset " TROPICAL GEMM $(A_type) *$(B_type) +$(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
123
- (A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128 )],
124
- transpose_a = [false , true ],
125
- transpose_b = [false , true ],
122
+ @testset " TROPICAL GEMM $(A_type) *$(B_type) +$(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
123
+ (A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128 )],
124
+ transpose_a = [false , true ],
125
+ transpose_b = [false , true ],
126
126
(OP_M, OP_N, OP_K) in [(8 , 16 , 2 )]
127
127
128
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 , 2 , 1 ], [1 , 1 , 2 ], [2 , 2 , 2 ]])
128
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 , 2 , 1 ], [1 , 1 , 2 ], [2 , 2 , 2 ]])
129
129
a_h = rand (A_type, (M, K)) / sqrt (A_type (K))
130
130
b_h = rand (B_type, (K, N)) / sqrt (B_type (K))
131
131
c_h = rand (CD_type, (M, N))
@@ -135,7 +135,7 @@ using LinearAlgebra
135
135
for j in 1 : N
136
136
d_h[i, j] = c_h[i, j]
137
137
for k in 1 : K
138
- d_h[i, j] = max (a_h[i, k] + b_h[k, j], d_h[i, j])
138
+ d_h[i, j] = max (a_h[i, k] + b_h[k, j], d_h[i, j])
139
139
end
140
140
end
141
141
end
@@ -164,16 +164,16 @@ using LinearAlgebra
164
164
)
165
165
166
166
GemmKernels. matmul (a, b, c, d, conf; kernel = Kernel. matmul_pipelined)
167
-
167
+
168
168
@test all (isapprox .(d_h, Array (d); rtol = sqrt (eps (A_type))))
169
169
end
170
170
end
171
171
172
172
173
- @test_if " wmma " @ testset " WMMA GEMM $(A_type) *$(B_type) +$(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
173
+ @testset " WMMA GEMM $(A_type) *$(B_type) +$(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
174
174
transpose_b = [false , true ],
175
175
(A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256 ), (Float16, Float16, Float32, 128 )]
176
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 ,2 ,1 ], [1 ,1 ,2 ], [2 ,2 ,2 ]], [[2048 , 2048 , 2048 ]])
176
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 ,2 ,1 ], [1 ,1 ,2 ], [2 ,2 ,2 ]], [[2048 , 2048 , 2048 ]])
177
177
alpha = convert (A_type, 2 )
178
178
beta = convert (CD_type, 3 )
179
179
@@ -217,10 +217,10 @@ using LinearAlgebra
217
217
end
218
218
end
219
219
220
- @test_if " bias " @ testset " WMMA GEMM ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) + bias" for transpose_a = [false , true ],
220
+ @testset " WMMA GEMM ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) + bias" for transpose_a = [false , true ],
221
221
transpose_b = [false , true ]
222
222
223
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (4096 , 4096 , 4096 )]
223
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (4096 , 4096 , 4096 )]
224
224
a_h = rand (Float16, (M, K)) / sqrt (Float16 (K))
225
225
b_h = rand (Float16, (K, N)) / sqrt (Float16 (K))
226
226
c_h = rand (Float32, (M, N))
@@ -268,8 +268,8 @@ using LinearAlgebra
268
268
end
269
269
end
270
270
271
- @test_if " diagonal " @ testset " WMMA GEMM (A = diagonal, B = $( ! transpose_b ? ' N' : ' T' ) )" for transpose_b = [false , true ]
272
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (4096 , 4096 , 4096 )]
271
+ @testset " WMMA GEMM (A = diagonal, B = $( ! transpose_b ? ' N' : ' T' ) )" for transpose_b = [false , true ]
272
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (4096 , 4096 , 4096 )]
273
273
@assert M == K " Diagonal only supports square A matrix (M == K)"
274
274
275
275
transpose_a = false
@@ -312,10 +312,10 @@ using LinearAlgebra
312
312
end
313
313
end
314
314
315
- @test_if " complex " @ testset " WMMA Complex GEMM ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
315
+ @testset " WMMA Complex GEMM ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
316
316
transpose_b = [false , true ]
317
317
318
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) = [(128 , 128 , 128 ), (256 , 256 , 256 ), (2048 , 2048 , 2048 )]
318
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) = [(128 , 128 , 128 ), (256 , 256 , 256 ), (2048 , 2048 , 2048 )]
319
319
a_h = rand (Complex{Float16}, (M, K)) / sqrt (Float16 (K));
320
320
b_h = rand (Complex{Float16}, (K, N)) / sqrt (Float16 (K));
321
321
c_h = rand (Complex{Float32}, (M, N));
@@ -377,8 +377,8 @@ using LinearAlgebra
377
377
end
378
378
end
379
379
380
- @test_if " dual " @ testset " WMMA Dual GEMM" begin
381
- @testset " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (2048 , 2048 , 2048 )]
380
+ @testset " WMMA Dual GEMM" begin
381
+ @testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (2048 , 2048 , 2048 )]
382
382
a_h = rand (Complex{Float16}, (M, K)) / sqrt (Float16 (K));
383
383
b_h = rand (Complex{Float16}, (K, N)) / sqrt (Float16 (K));
384
384
c_h = rand (Complex{Float32}, (M, N));
0 commit comments