@@ -6,7 +6,7 @@ using LinearAlgebra
6
6
# ###############################################################################
7
7
8
8
@testset " Matmul API" begin
9
- @testset " FPU GEMM $(A_type) *$(B_type) + $(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
9
+ @testset " FPU GEMM $(A_type) *$(B_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
10
10
(A_type, B_type, CD_type, min_dimension) in [
11
11
(Float16, Float16, Float32, 128 ), (Float32, Float32, Float32, 128 ), (Float32, Float32, Float64, 128 ), (Float64, Float64, Float64, 128 ),
12
12
(Int16, Int16, Int16, 128 ), (Int32, Int32, Int32, 128 ), (Int64, Int64, Int64, 128 ),
@@ -63,10 +63,11 @@ using LinearAlgebra
63
63
new_a_h = transpose_a ? transpose (a_h) : a_h
64
64
new_b_h = transpose_b ? transpose (b_h) : b_h
65
65
66
+ mul! (c_h, new_a_h, new_b_h, alpha, beta)
66
67
if A_type <: Integer
67
- @test all ( isapprox .(alpha * CD_type .(new_a_h) * CD_type .(new_b_h) + beta * c_h, Array (d)) )
68
+ @test c_h ≈ Array (d)
68
69
else
69
- @test all ( isapprox .(alpha * CD_type .(new_a_h) * CD_type .(new_b_h) + beta * c_h, Array (d); rtol = sqrt (eps (A_type)) ))
70
+ @test c_h ≈ Array (d) rtol= sqrt (eps (A_type))
70
71
end
71
72
end
72
73
end
@@ -120,11 +121,12 @@ using LinearAlgebra
120
121
new_a_h = transpose_a ? transpose (a_h) : a_h
121
122
new_b_h = transpose_b ? transpose (b_h) : b_h
122
123
123
- @test all (isapprox .(alpha * CD_type .(new_a_h) * CD_type .(new_b_h) + beta * c_h, Array (d); rtol = sqrt (eps (A_type))))
124
+ mul! (c_h, new_a_h, new_b_h, alpha, beta)
125
+ @test c_h ≈ Array (d) rtol= sqrt (eps (A_type))
124
126
end
125
127
end
126
128
127
- @testset " TROPICAL GEMM $(A_type) *$(B_type) + $(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
129
+ @testset " TROPICAL GEMM $(A_type) *$(B_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) ) OP ($(OP_M) , $(OP_N) , $(OP_K) )" for
128
130
(A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128 )],
129
131
transpose_a = [false , true ],
130
132
transpose_b = [false , true ],
@@ -172,12 +174,12 @@ using LinearAlgebra
172
174
173
175
GemmKernels. matmul (a, b, c, d, conf; kernel = Kernel. matmul_pipelined)
174
176
175
- @test all ( isapprox .( d_h, Array (d); rtol = sqrt (eps (A_type)) ))
177
+ @test d_h ≈ Array (d) rtol= sqrt (eps (A_type))
176
178
end
177
179
end
178
180
179
181
180
- @testset " WMMA GEMM $(AB_type) *$(AB_type) + $(CD_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
182
+ @testset " WMMA GEMM $(AB_type) *$(AB_type) =$(CD_type) ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
181
183
transpose_b = [false , true ],
182
184
(AB_type, CD_type, min_dimension) in [(Float16, Float16, 256 ), (Float16, Float32, 128 )]
183
185
@testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in vcat (min_dimension.* [[1 ,1 ,1 ], [2 ,2 ,1 ], [1 ,1 ,2 ], [2 ,2 ,2 ]], [[2048 , 2048 , 2048 ]])
@@ -220,7 +222,8 @@ using LinearAlgebra
220
222
new_a_h = transpose_a ? transpose (a_h) : a_h
221
223
new_b_h = transpose_b ? transpose (b_h) : b_h
222
224
223
- @test all (isapprox .(alpha * CD_type .(new_a_h) * CD_type .(new_b_h) + beta * c_h, Array (d); rtol = sqrt (eps (AB_type))))
225
+ mul! (c_h, new_a_h, new_b_h, alpha, beta)
226
+ @test c_h ≈ Array (d) rtol= sqrt (eps (AB_type))
224
227
end
225
228
end
226
229
@@ -271,7 +274,8 @@ using LinearAlgebra
271
274
new_a_h = transpose_a ? transpose (a_h) : a_h
272
275
new_b_h = transpose_b ? transpose (b_h) : b_h
273
276
274
- @test all (isapprox .(Float32 .(new_a_h) * Float32 .(new_b_h) + c_h .+ Array (bias), Array (d); rtol = sqrt (eps (Float16))))
277
+ mul! (c_h, new_a_h, new_b_h, true , true )
278
+ @test c_h .+ Array (bias) ≈ Array (d) rtol= sqrt (eps (Float16))
275
279
end
276
280
end
277
281
@@ -281,7 +285,7 @@ using LinearAlgebra
281
285
282
286
transpose_a = false
283
287
284
- a_h = rand (Float16, M);
288
+ a_h = rand (Float16, M)
285
289
b_h = rand (Float16, (K, N)) / sqrt (Float16 (K))
286
290
c_h = rand (Float32, (M, N))
287
291
@@ -315,26 +319,27 @@ using LinearAlgebra
315
319
new_a_h = transpose_a ? transpose (a_h) : a_h
316
320
new_b_h = transpose_b ? transpose (b_h) : b_h
317
321
318
- @test all (isapprox .(Float32 .(Diagonal (new_a_h)) * Float32 .(new_b_h) + c_h, Array (d); rtol = sqrt (eps (Float16))))
322
+ mul! (c_h, Diagonal (new_a_h), new_b_h, true , true )
323
+ @test c_h ≈ Array (d) rtol= sqrt (eps (Float16))
319
324
end
320
325
end
321
326
322
327
@testset " WMMA Complex GEMM ($( ! transpose_a ? ' N' : ' T' )$( ! transpose_b ? ' N' : ' T' ) )" for transpose_a = [false , true ],
323
328
transpose_b = [false , true ]
324
329
325
330
@testcase " (M = $M , N = $N , K = $K )" for (M, N, K) = [(128 , 128 , 128 ), (256 , 256 , 256 ), (2048 , 2048 , 2048 )]
326
- a_h = rand (Complex{Float16}, (M, K)) / sqrt (Float16 (K));
327
- b_h = rand (Complex{Float16}, (K, N)) / sqrt (Float16 (K));
328
- c_h = rand (Complex{Float32}, (M, N));
331
+ a_h = rand (Complex{Float16}, (M, K)) / sqrt (Float16 (K))
332
+ b_h = rand (Complex{Float16}, (K, N)) / sqrt (Float16 (K))
333
+ c_h = rand (Complex{Float32}, (M, N))
329
334
330
335
# Transpose input if necessary
331
336
a_h = transpose_a ? transpose (a_h) : a_h
332
337
b_h = transpose_b ? transpose (b_h) : b_h
333
338
334
- a = CuArray (a_h);
335
- b = CuArray (b_h);
336
- c = CuArray (c_h);
337
- d = similar (c);
339
+ a = CuArray (a_h)
340
+ b = CuArray (b_h)
341
+ c = CuArray (c_h)
342
+ d = similar (c)
338
343
339
344
conf = GemmKernels. get_config (
340
345
gemm_shape = (M = M, N = N, K = K),
@@ -378,22 +383,21 @@ using LinearAlgebra
378
383
new_a_h = transpose_a ? transpose (new_a_h) : new_a_h
379
384
new_b_h = transpose_b ? transpose (new_b_h) : new_b_h
380
385
381
- # TODO : Figure out why changing this to a * b + c = d instead of a * b = d - c
382
- # makes tests fail for CC (see #19).
383
- @test all (isapprox .(Complex {Float32} .(new_a_h) * Complex {Float32} .(new_b_h), Array (d) - c_h; rtol= sqrt (eps (Float16))));
386
+ mul! (c_h, new_a_h, new_b_h, true , true )
387
+ @test c_h ≈ Array (d) rtol= sqrt (eps (Float16))
384
388
end
385
389
end
386
390
387
391
@testset " WMMA Dual GEMM" begin
388
392
@testcase " (M = $M , N = $N , K = $K )" for (M, N, K) in [(128 , 128 , 128 ), (256 , 256 , 256 ), (2048 , 2048 , 2048 )]
389
- a_h = rand (Complex{Float16}, (M, K)) / sqrt (Float16 (K));
390
- b_h = rand (Complex{Float16}, (K, N)) / sqrt (Float16 (K));
391
- c_h = rand (Complex{Float32}, (M, N));
393
+ a_h = rand (Complex{Float16}, (M, K)) / sqrt (Float16 (K))
394
+ b_h = rand (Complex{Float16}, (K, N)) / sqrt (Float16 (K))
395
+ c_h = rand (Complex{Float32}, (M, N))
392
396
393
- a = CuArray (a_h);
394
- b = CuArray (b_h);
395
- c = CuArray (c_h);
396
- d = similar (c);
397
+ a = CuArray (a_h)
398
+ b = CuArray (b_h)
399
+ c = CuArray (c_h)
400
+ d = similar (c)
397
401
398
402
conf = GemmKernels. get_config (
399
403
gemm_shape = (M = M, N = N, K = K),
@@ -432,7 +436,8 @@ using LinearAlgebra
432
436
c_dual = reinterpret (ForwardDiff. Dual{Float32,Float32,1 }, c_h)
433
437
d_dual = reinterpret (ForwardDiff. Dual{Float32,Float32,1 }, Array (d))
434
438
435
- @test all (isapprox .(a_dual * b_dual + c_dual, d_dual; rtol= sqrt (eps (Float16))));
439
+ mul! (c_dual, a_dual, b_dual, true , true )
440
+ @test c_dual ≈ d_dual rtol= sqrt (eps (Float16))
436
441
end
437
442
end
438
443
end
0 commit comments