Skip to content

Commit 336724f

Browse files
committed
Remove magic and add correctness to testsuite compare function.
1 parent 5665483 commit 336724f

File tree

8 files changed

+72
-89
lines changed

8 files changed

+72
-89
lines changed

src/testsuite.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,15 @@ using FFTW
1313
using FillArrays
1414
using StaticArrays
1515

16-
toarray(T, x::Tuple{X, Vararg{Int}}) where X = fill(first(x), Base.tail(x))
17-
toarray(::Type{T}, x::NTuple{N, Int}) where {T <: Bool, N} = rand(T, x)
18-
toarray(::Type{T}, x::NTuple{N, Int}) where {T <: Integer, N} = rand(T(1):T(10), x)
19-
toarray(T, x::NTuple{N, Int}) where N = rand(T, x)
20-
toarray(T, x) = x
21-
togpu(T, x::AbstractArray) = T(x)
22-
togpu(T, x) = x
16+
convert_array(f, x) = f(x)
17+
convert_array(f, x::Base.RefValue) = x[]
2318

24-
"""
25-
Calls function `f` on input arrays generated by `sizes` as Base.Array and converted to
26-
`Typ`. Compares the result of `f` and tests if they agree. `sizes` can be the shape of the
27-
array, a value or a tuple `(val, shape...)` which will create a `fill(val, shape...)`.
28-
"""
29-
function against_base(f, Typ, sizes...)
30-
jl_arrays = toarray.(eltype(Typ), sizes)
31-
gpu_arrays = togpu.(Typ, jl_arrays)
32-
res_jl = f(jl_arrays...)
33-
res_gpu = f(gpu_arrays...)
34-
@test res_jl Array(res_gpu)
19+
function compare(f, Typ, xs...)
20+
cpu_in = convert_array.(copy, xs)
21+
gpu_in = convert_array.(Typ, xs)
22+
cpu_out = f(cpu_in...)
23+
gpu_out = f(gpu_in...)
24+
cpu_out Array(gpu_out)
3525
end
3626

3727

@@ -52,7 +42,7 @@ function supported_eltypes()
5242
(Float32, Float64, Int32, Int64, ComplexF32, ComplexF64)
5343
end
5444

55-
export against_base, run_tests, supported_eltypes
45+
export run_tests, supported_eltypes
5646

5747
end
5848

src/testsuite/base.jl

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ function test_base(Typ)
4747
@test Array(A) a
4848
end
4949

50-
5150
@testset "copyto!" begin
5251
x = fill(0f0, (10, 10))
5352
y = rand(Float32, (20, 10))
@@ -69,17 +68,12 @@ function test_base(Typ)
6968
end
7069

7170
@testset "vcat + hcat" begin
72-
x = fill(0f0, (10, 10))
73-
y = rand(Float32, 20, 10)
74-
a = Typ(x)
75-
b = Typ(y)
76-
@test vcat(x, y) == Array(vcat(a, b))
77-
z = rand(Float32, 10, 10)
78-
c = Typ(z)
79-
@test hcat(x, z) == Array(hcat(a, c))
71+
@test compare(vcat, Typ, fill(0f0, (10, 10)), rand(Float32, 20, 10))
72+
@test compare(hcat, Typ, fill(0f0, (10, 10)), rand(Float32, 10, 10))
8073

81-
against_base(hcat, Typ{Float32}, (3, 3), (3, 3))
82-
against_base(vcat, Typ{Float32}, (3, 3), (3, 3))
74+
@test compare(hcat, Typ, rand(Float32, 3, 3), rand(Float32, 3, 3))
75+
@test compare(vcat, Typ, rand(Float32, 3, 3), rand(Float32, 3, 3))
76+
@test compare((a,b) -> cat(a, b; dims=4), Typ, rand(Float32, 3, 4), rand(Float32, 3, 4))
8377
end
8478

8579
@testset "reinterpret" begin
@@ -121,18 +115,17 @@ function test_base(Typ)
121115
@test map!(-, jy, jy) Array(x)
122116
end
123117

124-
T = Typ{Float32}
125118
@testset "map" begin
126-
against_base((a, b)-> map(+, a, b), T, (10,), (10,))
127-
against_base((a, b)-> map!(-, a, b), T, (10,), (10,))
128-
against_base((a, b, c, d)-> map!(*, a, b, c, d), T, (10,), (10,), (10,), (10,))
119+
@test compare((a, b)-> map(+, a, b), Typ, rand(Float32, 10), rand(Float32, 10))
120+
@test compare((a, b)-> map!(-, a, b), Typ, rand(Float32, 10), rand(Float32, 10))
121+
@test compare((a, b, c, d)-> map!(*, a, b, c, d), Typ, rand(Float32, 10), rand(Float32, 10), rand(Float32, 10), rand(Float32, 10))
129122
end
130123

131124
@testset "repeat" begin
132-
against_base(a-> repeat(a, 5, 6), T, (10,))
133-
against_base(a-> repeat(a, 5), T, (10,))
134-
against_base(a-> repeat(a, 5), T, (5, 4))
135-
against_base(a-> repeat(a, 4, 3), T, (10, 15))
125+
@test compare(a-> repeat(a, 5, 6), Typ, rand(Float32, 10))
126+
@test compare(a-> repeat(a, 5), Typ, rand(Float32, 10))
127+
@test compare(a-> repeat(a, 5), Typ, rand(Float32, 5, 4))
128+
@test compare(a-> repeat(a, 4, 3), Typ, rand(Float32, 10, 15))
136129
end
137130
end
138131
end

src/testsuite/blas.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
function test_blas(Typ)
22
@testset "BLAS" begin
3-
T = Typ{Float32}
43
@testset "matmul" begin
5-
against_base(*, T, (5, 5), (5, 5))
6-
against_base(*, T, (5, 5), (5,))
7-
against_base((a, b)-> a * transpose(b), T, (5, 5), (5, 5))
8-
against_base((c, a, b)-> mul!(c, a, transpose(b)), T, (10, 32), (10, 60), (32, 60))
9-
against_base((a, b)-> transpose(a) * b, T, (5, 5), (5, 5))
10-
against_base((a, b)-> transpose(a) * transpose(b), T, (10, 15), (1, 10))
11-
against_base((a, b)-> transpose(a) * b, T, (10, 15), (10,))
12-
against_base(mul!, T, (15,), (15, 10), (10,))
4+
@test compare(*, Typ, rand(Float32, 5, 5), rand(Float32, 5, 5))
5+
@test compare(*, Typ, rand(Float32, 5, 5), rand(Float32, 5))
6+
@test compare((a, b)-> a * transpose(b), Typ, rand(Float32, 5, 5), rand(Float32, 5, 5))
7+
@test compare((c, a, b)-> mul!(c, a, transpose(b)), Typ, rand(Float32, 10, 32), rand(Float32, 10, 60), rand(Float32, 32, 60))
8+
@test compare((a, b)-> transpose(a) * b, Typ, rand(Float32, 5, 5), rand(Float32, 5, 5))
9+
@test compare((a, b)-> transpose(a) * transpose(b), Typ, rand(Float32, 10, 15), rand(Float32, 1, 10))
10+
@test compare((a, b)-> transpose(a) * b, Typ, rand(Float32, 10, 15), rand(Float32, 10))
11+
@test compare(mul!, Typ, rand(Float32, 15), rand(Float32, 15, 10), rand(Float32, 10))
1312
end
13+
1414
for T in (ComplexF32, Float32)
1515
@testset "rmul! $T" begin
16-
against_base(rmul!, Typ{T}, (13, 23), 77f0)
16+
@test compare(rmul!, Typ, rand(T, 13, 23), Ref(77f0))
1717
end
1818
end
19+
1920
@testset "gbmv" begin
2021
m, n = 10, 11
2122
A, x, y = randn(Float32, 3, n), randn(Float32, n), fill(0f0, m)

src/testsuite/broadcasting.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,60 +32,61 @@ end
3232
function broadcasting(Typ)
3333
for ET in supported_eltypes()
3434
N = 10
35-
T = Typ{ET}
3635
@testset "broadcast $ET" begin
3736
@testset "RefValue" begin
3837
cidx = rand(1:Int(N), 2*N)
3938
gidx = Typ(cidx)
40-
cy = toarray(ET, (2*N,))
39+
cy = rand(ET, 2*N)
4140
gy = Typ(cy)
4241
cres = fill(zero(ET), size(cidx))
4342
gres = Typ(cres)
4443
gres .= test_idx.(gidx, Base.RefValue(gy))
4544
cres .= test_idx.(cidx, Base.RefValue(cy))
4645
@test Array(gres) == cres
4746
end
47+
4848
@testset "Tuple" begin
49-
against_base(T, (3, N), (3, N), (N,), (N,), (N,)) do out, arr, a, b, c
50-
res2 = broadcast!(out, arr, (a, b, c)) do xx, yy
51-
xx + sum(yy)
49+
@test compare(Typ, rand(ET, 3, N), rand(ET, 3, N), rand(ET, N), rand(ET, N), rand(ET, N)) do out, arr, a, b, c
50+
broadcast!(out, arr, (a, b, c)) do xx, yy
51+
xx + first(yy)
5252
end
5353
end
5454
end
55+
5556
############
5657
# issue #27
57-
against_base((a, b)-> a .+ b, T, (4, 5, 3), (1, 5, 3))
58-
against_base((a, b)-> a .+ b, T, (4, 5, 3), (1, 5, 1))
58+
@test compare((a, b)-> a .+ b, Typ, rand(ET, 4, 5, 3), rand(ET, 1, 5, 3))
59+
@test compare((a, b)-> a .+ b, Typ, rand(ET, 4, 5, 3), rand(ET, 1, 5, 1))
5960

6061
############
6162
# issue #22
6263
dim = (32, 32)
63-
against_base(T, dim, dim, dim) do tmp, a1, a2
64+
@test compare(Typ, rand(ET, dim), rand(ET, dim), rand(ET, dim)) do tmp, a1, a2
6465
tmp .= a1 .+ a2 .* ET(2)
6566
end
6667

6768
############
6869
# issue #21
6970
if ET in (Float32, Float64)
70-
against_base((a1, a2)-> muladd.(ET(2), a1, a2), T, dim, dim)
71+
@test compare((a1, a2)-> muladd.(ET(2), a1, a2), Typ, rand(ET, dim), rand(ET, dim))
7172
#########
7273
# issue #41
7374
# The first issue is likely https://github.com/JuliaLang/julia/issues/22255
7475
# since GPUArrays adds some arguments to the function, it becomes longer longer, hitting the 12
7576
# so this wont fix for now
76-
against_base(T, dim, dim, dim, dim, dim, dim) do a1, a2, a3, a4, a5, a6
77+
@test compare(Typ, rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim)) do a1, a2, a3, a4, a5, a6
7778
@. a1 = a2 + (1.2) *((1.3)*a3 + (1.4)*a4 + (1.5)*a5 + (1.6)*a6)
7879
end
7980

80-
against_base(T, dim, dim, dim, dim) do u, uprev, duprev, ku
81+
@test compare(Typ, rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim)) do u, uprev, duprev, ku
8182
fract = ET(1//2)
8283
dt = ET(1.4)
8384
dt2 = dt^2
8485
@. u = uprev + dt*duprev + dt2*(fract*ku)
8586
end
86-
against_base((x) -> (-).(x), T, (2, 3))
87+
@test compare((x) -> (-).(x), Typ, rand(ET, 2, 3))
8788

88-
against_base(T, dim, dim, dim, dim, dim, dim) do utilde, gA, k1, k2, k3, k4
89+
@test compare(Typ, rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim), rand(ET, dim)) do utilde, gA, k1, k2, k3, k4
8990
btilde1 = ET(1)
9091
btilde2 = ET(1)
9192
btilde3 = ET(1)
@@ -95,19 +96,18 @@ function broadcasting(Typ)
9596
end
9697
end
9798

98-
against_base((x) -> fill!(x, 1), T, (3,3))
99-
against_base((x, y) -> map(+, x, y), T, (2, 3), (2, 3))
99+
@test compare((x) -> fill!(x, 1), Typ, rand(ET, 3,3))
100+
@test compare((x, y) -> map(+, x, y), Typ, rand(ET, 2, 3), rand(ET, 2, 3))
100101

101-
against_base((x) -> 2x, T, (2, 3))
102-
against_base((x, y) -> x .+ y, T, (2, 3), (1, 3))
103-
against_base((z, x, y) -> z .= x .+ y, T, (2, 3), (2, 3), (2,))
102+
@test compare((x) -> 2x, Typ, rand(ET, 2, 3))
103+
@test compare((x, y) -> x .+ y, Typ, rand(ET, 2, 3), rand(ET, 1, 3))
104+
@test compare((z, x, y) -> z .= x .+ y, Typ, rand(ET, 2, 3), rand(ET, 2, 3), rand(ET, 2))
104105

105-
T = Typ{ET}
106-
against_base(A -> A .= identity.(ET(10)), T, (40, 40))
107-
against_base(A -> test_kernel.(A, ET(10)), T, (40, 40))
108-
against_base(A -> A .* ET(10), T, (40, 40))
109-
against_base((A, B) -> A .* B, T, (40, 40), (40, 40))
110-
against_base((A, B) -> A .* B .+ ET(10), T, (40, 40), (40, 40))
106+
@test compare(A -> A .= identity.(ET(10)), Typ, rand(ET, 40, 40))
107+
@test compare(A -> test_kernel.(A, ET(10)), Typ, rand(ET, 40, 40))
108+
@test compare(A -> A .* ET(10), Typ, rand(ET, 40, 40))
109+
@test compare((A, B) -> A .* B, Typ, rand(ET, 40, 40), rand(ET, 40, 40))
110+
@test compare((A, B) -> A .* B .+ ET(10), Typ, rand(ET, 40, 40), rand(ET, 40, 40))
111111
end
112112
end
113113
end

src/testsuite/construction.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function test_construction(Typ)
2-
@testset "Construction" begin
2+
@testset "construction" begin
33
constructors(Typ)
44
conversion(Typ)
55
value_constructor(Typ)
@@ -89,7 +89,7 @@ function conversion(Typ)
8989
end
9090

9191
function value_constructor(Typ)
92-
@testset "value constructor" begin
92+
@testset "value constructors" begin
9393
for T in supported_eltypes()
9494
x = fill(zero(T), (2, 2))
9595
x1 = fill(Typ{T}, T(0), (2, 2))

src/testsuite/fft.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
function test_fft(Typ)
2-
T = Typ{ComplexF32}
32
for n = 1:3
43
@testset "FFT with ND = $n" begin
5-
dims = ntuple(i-> 40, n)
6-
against_base(fft!, T, dims)
7-
against_base(ifft!, T, dims)
4+
dims = ntuple(i -> 40, n)
5+
@test compare(fft!, Typ, rand(ComplexF32, dims))
6+
@test compare(ifft!, Typ, rand(ComplexF32, dims))
87

9-
against_base(fft, T, dims)
10-
against_base(ifft, T, dims)
8+
@test compare(fft, Typ, rand(ComplexF32, dims))
9+
@test compare(ifft, Typ, rand(ComplexF32, dims))
1110
end
1211
end
1312
end

src/testsuite/io.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
function test_io(Typ)
2-
@testset "i/o" begin
3-
@testset "Showing" begin
2+
@testset "input/output" begin
3+
@testset "showing" begin
44
io = IOBuffer()
55
A = Typ([1])
66

src/testsuite/linalg.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
function test_linalg(Typ)
2-
T = Typ{Float32}
3-
@testset "Linalg" begin
2+
@testset "linear algebra" begin
43
@testset "transpose" begin
5-
against_base(adjoint, T, (32, 32))
4+
@test compare(adjoint, Typ, rand(Float32, 32, 32))
65
end
7-
@testset "PermuteDims" begin
8-
against_base(x -> permutedims(x, (2, 1)), T, (2, 3))
9-
against_base(x -> permutedims(x, (2, 1, 3)), T, (4, 5, 6))
10-
against_base(x -> permutedims(x, (3, 1, 2)), T, (4, 5, 6))
6+
7+
@testset "permutedims" begin
8+
@test compare(x -> permutedims(x, (2, 1)), Typ, rand(Float32, 2, 3))
9+
@test compare(x -> permutedims(x, (2, 1, 3)), Typ, rand(Float32, 4, 5, 6))
10+
@test compare(x -> permutedims(x, (3, 1, 2)), Typ, rand(Float32, 4, 5, 6))
1111
end
1212
end
1313
end

0 commit comments

Comments
 (0)