Skip to content

Commit 781f1de

Browse files
authored
Simplify tests. (#124)
1 parent fa335a1 commit 781f1de

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

test/blas.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@ using LinearAlgebra
55
CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)
66

77
@testset "BLAS API" begin
8-
@testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
8+
@testset "WMMA GEMM $(AB_type)*$(AB_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
99
transpose_b = [false, true],
10-
(A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]
10+
(AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)]
1111

1212
@testcase "(M = $M, N = $N, K = $K)" for M in min_dimension .* [1, 2],
1313
N in min_dimension .* [1, 2],
1414
K in min_dimension .* [1, 2]
1515

16-
alpha = rand(A_type)
16+
alpha = rand(AB_type)
1717
beta = rand(CD_type)
1818

19-
a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
20-
b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
19+
a_h = rand(AB_type, (M, K)) / sqrt(AB_type(K))
20+
b_h = rand(AB_type, (K, N)) / sqrt(AB_type(K))
2121
c_h = rand(CD_type, (M, N))
2222

2323
# Transpose input if necessary
@@ -33,7 +33,7 @@ CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)
3333
c_cublas = CuArray(c_h)
3434
CUDA.CUBLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, a, b, beta, c_cublas)
3535

36-
@test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(A_type))));
36+
@test Array(c_gemmkernels) Array(c_cublas) rtol=sqrt(eps(AB_type))
3737
end
3838
end
3939
end

test/matmul.jl

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using LinearAlgebra
66
################################################################################
77

88
@testset "Matmul API" begin
9-
@testset "FPU GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
9+
@testset "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
1010
(A_type, B_type, CD_type, min_dimension) in [
1111
(Float16, Float16, Float32, 128), (Float32, Float32, Float32, 128), (Float32, Float32, Float64, 128), (Float64, Float64, Float64, 128),
1212
(Int16, Int16, Int16, 128), (Int32, Int32, Int32, 128), (Int64, Int64, Int64, 128),
@@ -63,10 +63,11 @@ using LinearAlgebra
6363
new_a_h = transpose_a ? transpose(a_h) : a_h
6464
new_b_h = transpose_b ? transpose(b_h) : b_h
6565

66+
mul!(c_h, new_a_h, new_b_h, alpha, beta)
6667
if A_type <: Integer
67-
@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d)))
68+
@test c_h Array(d)
6869
else
69-
@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
70+
@test c_h Array(d) rtol=sqrt(eps(A_type))
7071
end
7172
end
7273
end
@@ -120,11 +121,12 @@ using LinearAlgebra
120121
new_a_h = transpose_a ? transpose(a_h) : a_h
121122
new_b_h = transpose_b ? transpose(b_h) : b_h
122123

123-
@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
124+
mul!(c_h, new_a_h, new_b_h, alpha, beta)
125+
@test c_h Array(d) rtol=sqrt(eps(A_type))
124126
end
125127
end
126128

127-
@testset "TROPICAL GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
129+
@testset "TROPICAL GEMM $(A_type)*$(B_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
128130
(A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128)],
129131
transpose_a = [false, true],
130132
transpose_b = [false, true],
@@ -172,12 +174,12 @@ using LinearAlgebra
172174

173175
GemmKernels.matmul(a, b, c, d, conf; kernel = Kernel.matmul_pipelined)
174176

175-
@test all(isapprox.(d_h, Array(d); rtol = sqrt(eps(A_type))))
177+
@test d_h Array(d) rtol=sqrt(eps(A_type))
176178
end
177179
end
178180

179181

180-
@testset "WMMA GEMM $(AB_type)*$(AB_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
182+
@testset "WMMA GEMM $(AB_type)*$(AB_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
181183
transpose_b = [false, true],
182184
(AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)]
183185
@testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2,2,1], [1,1,2], [2,2,2]], [[2048, 2048, 2048]])
@@ -220,7 +222,8 @@ using LinearAlgebra
220222
new_a_h = transpose_a ? transpose(a_h) : a_h
221223
new_b_h = transpose_b ? transpose(b_h) : b_h
222224

223-
@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(AB_type))))
225+
mul!(c_h, new_a_h, new_b_h, alpha, beta)
226+
@test c_h Array(d) rtol=sqrt(eps(AB_type))
224227
end
225228
end
226229

@@ -271,7 +274,8 @@ using LinearAlgebra
271274
new_a_h = transpose_a ? transpose(a_h) : a_h
272275
new_b_h = transpose_b ? transpose(b_h) : b_h
273276

274-
@test all(isapprox.(Float32.(new_a_h) * Float32.(new_b_h) + c_h .+ Array(bias), Array(d); rtol = sqrt(eps(Float16))))
277+
mul!(c_h, new_a_h, new_b_h, true, true)
278+
@test c_h .+ Array(bias) Array(d) rtol=sqrt(eps(Float16))
275279
end
276280
end
277281

@@ -281,7 +285,7 @@ using LinearAlgebra
281285

282286
transpose_a = false
283287

284-
a_h = rand(Float16, M);
288+
a_h = rand(Float16, M)
285289
b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
286290
c_h = rand(Float32, (M, N))
287291

@@ -315,26 +319,27 @@ using LinearAlgebra
315319
new_a_h = transpose_a ? transpose(a_h) : a_h
316320
new_b_h = transpose_b ? transpose(b_h) : b_h
317321

318-
@test all(isapprox.(Float32.(Diagonal(new_a_h)) * Float32.(new_b_h) + c_h, Array(d); rtol = sqrt(eps(Float16))))
322+
mul!(c_h, Diagonal(new_a_h), new_b_h, true, true)
323+
@test c_h Array(d) rtol=sqrt(eps(Float16))
319324
end
320325
end
321326

322327
@testset "WMMA Complex GEMM ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
323328
transpose_b = [false, true]
324329

325330
@testcase "(M = $M, N = $N, K = $K)" for (M, N, K) = [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)]
326-
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K));
327-
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K));
328-
c_h = rand(Complex{Float32}, (M, N));
331+
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K))
332+
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K))
333+
c_h = rand(Complex{Float32}, (M, N))
329334

330335
# Transpose input if necessary
331336
a_h = transpose_a ? transpose(a_h) : a_h
332337
b_h = transpose_b ? transpose(b_h) : b_h
333338

334-
a = CuArray(a_h);
335-
b = CuArray(b_h);
336-
c = CuArray(c_h);
337-
d = similar(c);
339+
a = CuArray(a_h)
340+
b = CuArray(b_h)
341+
c = CuArray(c_h)
342+
d = similar(c)
338343

339344
conf = GemmKernels.get_config(
340345
gemm_shape = (M = M, N = N, K = K),
@@ -378,22 +383,21 @@ using LinearAlgebra
378383
new_a_h = transpose_a ? transpose(new_a_h) : new_a_h
379384
new_b_h = transpose_b ? transpose(new_b_h) : new_b_h
380385

381-
# TODO: Figure out why changing this to a * b + c = d instead of a * b = d - c
382-
# makes tests fail for CC (see #19).
383-
@test all(isapprox.(Complex{Float32}.(new_a_h) * Complex{Float32}.(new_b_h), Array(d) - c_h; rtol=sqrt(eps(Float16))));
386+
mul!(c_h, new_a_h, new_b_h, true, true)
387+
@test c_h Array(d) rtol=sqrt(eps(Float16))
384388
end
385389
end
386390

387391
@testset "WMMA Dual GEMM" begin
388392
@testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)]
389-
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K));
390-
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K));
391-
c_h = rand(Complex{Float32}, (M, N));
393+
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K))
394+
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K))
395+
c_h = rand(Complex{Float32}, (M, N))
392396

393-
a = CuArray(a_h);
394-
b = CuArray(b_h);
395-
c = CuArray(c_h);
396-
d = similar(c);
397+
a = CuArray(a_h)
398+
b = CuArray(b_h)
399+
c = CuArray(c_h)
400+
d = similar(c)
397401

398402
conf = GemmKernels.get_config(
399403
gemm_shape = (M = M, N = N, K = K),
@@ -432,7 +436,8 @@ using LinearAlgebra
432436
c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h)
433437
d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d))
434438

435-
@test all(isapprox.(a_dual * b_dual + c_dual, d_dual; rtol=sqrt(eps(Float16))));
439+
mul!(c_dual, a_dual, b_dual, true, true)
440+
@test c_dual d_dual rtol=sqrt(eps(Float16))
436441
end
437442
end
438443
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ withenv("JULIA_NUM_THREADS" => 1, "OPENBLAS_NUM_THREADS" => 1) do
2222
end
2323

2424
@everywhere using XUnit
25-
runtests("tests.jl")
25+
runtests("tests.jl", ARGS...)

0 commit comments

Comments
 (0)