Skip to content

Commit a620345

Browse files
committed
tests: Split array testsets
1 parent f59dd37 commit a620345

File tree

6 files changed

+134
-140
lines changed

6 files changed

+134
-140
lines changed

test/array.jl renamed to test/array/core.jl

Lines changed: 9 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
using LinearAlgebra, SparseArrays, Random, SharedArrays
2-
import Dagger: DArray, chunks, domainchunks, treereduce_nd
3-
import Distributed: myid, procs
4-
import Statistics: mean, var, std
5-
import OnlineStats
6-
71
@testset "treereduce_nd" begin
82
xs = rand(1:10, 8,8,8)
93
concats = [(x...)->cat(x..., dims=n) for n in 1:3]
@@ -80,52 +74,6 @@ end
8074
end
8175
end
8276

83-
function test_mapreduce(f, init_func; no_init=true, zero_init=zero,
84-
types=(Int32, Int64, Float32, Float64),
85-
cmp=isapprox)
86-
@testset "$T" for T in types
87-
X = init_func(Blocks(10, 10), T, 100, 100)
88-
inits = ()
89-
if no_init
90-
inits = (inits..., nothing)
91-
end
92-
if zero_init !== nothing
93-
inits = (inits..., zero_init(T))
94-
end
95-
@testset "dims=$dims" for dims in (Colon(), 1, 2, (1,), (2,))
96-
@testset "init=$init" for init in inits
97-
if init === nothing
98-
if dims == Colon()
99-
@test cmp(f(X; dims), f(collect(X); dims))
100-
else
101-
@test cmp(collect(f(X; dims)), f(collect(X); dims))
102-
end
103-
else
104-
if dims == Colon()
105-
@test cmp(f(X; dims, init), f(collect(X); dims, init))
106-
else
107-
@test cmp(collect(f(X; dims, init)), f(collect(X); dims, init))
108-
end
109-
end
110-
end
111-
end
112-
end
113-
end
114-
115-
# Base
116-
@testset "reduce" test_mapreduce((X; dims, init=Base._InitialValue())->reduce(+, X; dims, init), ones)
117-
@testset "mapreduce" test_mapreduce((X; dims, init=Base._InitialValue())->mapreduce(x->x+1, +, X; dims, init), ones)
118-
@testset "sum" test_mapreduce(sum, ones)
119-
@testset "prod" test_mapreduce(prod, rand)
120-
@testset "minimum" test_mapreduce(minimum, rand)
121-
@testset "maximum" test_mapreduce(maximum, rand)
122-
@testset "extrema" test_mapreduce(extrema, rand; cmp=Base.:(==), zero_init=T->(zero(T), zero(T)))
123-
124-
# Statistics
125-
@testset "mean" test_mapreduce(mean, rand; zero_init=nothing, types=(Float32, Float64))
126-
@testset "var" test_mapreduce(var, rand; zero_init=nothing, types=(Float32, Float64))
127-
@testset "std" test_mapreduce(std, rand; zero_init=nothing, types=(Float32, Float64))
128-
12977
@testset "broadcast" begin
13078
X1 = rand(Blocks(10), 100)
13179
X2 = X1 .* 3.4
@@ -138,7 +86,7 @@ end
13886

13987
@testset "distributing an array" begin
14088
function test_dist(X)
141-
X1 = distribute(X, Blocks(10, 20))
89+
X1 = Distribute(Blocks(10, 20), X)
14290
Xc = fetch(X1)
14391
@test Xc isa DArray{eltype(X),ndims(X)}
14492
@test Xc == X
@@ -147,7 +95,7 @@ end
14795
@test map(x->size(x) == (10, 20), domainchunks(Xc)) |> all
14896
end
14997
x = [1 2; 3 4]
150-
@test distribute(x, Blocks(1,1)) == x
98+
@test Distribute(Blocks(1,1), x) == x
15199
test_dist(rand(100, 100))
152100
test_dist(sprand(100, 100, 0.1))
153101

@@ -171,55 +119,26 @@ end
171119
test_transpose(sprand(100, 120, 0.1))
172120
end
173121

174-
@testset "matrix-matrix multiply" begin
175-
function test_mul(X)
176-
tol = 1e-12
177-
X1 = distribute(X, Blocks(10, 20))
178-
@test_throws DimensionMismatch X1*X1
179-
X2 = X1'*X1
180-
X3 = X1*X1'
181-
@test norm(collect(X2) - X'X) < tol
182-
@test norm(collect(X3) - X*X') < tol
183-
@test chunks(X2) |> size == (2, 2)
184-
@test chunks(X3) |> size == (4, 4)
185-
@test map(x->size(x) == (20, 20), domainchunks(X2)) |> all
186-
@test map(x->size(x) == (10, 10), domainchunks(X3)) |> all
187-
end
188-
test_mul(rand(40, 40))
189-
190-
x = rand(10,10)
191-
X = distribute(x, Blocks(3,3))
192-
y = rand(10)
193-
@test norm(collect(X*y) - x*y) < 1e-13
194-
end
195-
196-
@testset "matrix powers" begin
197-
x = rand(Blocks(4,4), 16, 16)
198-
@test collect(x^1) == collect(x)
199-
@test collect(x^2) == collect(x*x)
200-
@test collect(x^3) == collect(x*x*x)
201-
end
202-
203122
@testset "concat" begin
204123
m = rand(75,75)
205-
x = distribute(m, Blocks(10,20))
206-
y = distribute(m, Blocks(10,10))
124+
x = Distribute(Blocks(10,20), m)
125+
y = Distribute(Blocks(10,10), m)
207126
@test hcat(m,m) == collect(hcat(x,x)) == collect(hcat(x,y))
208127
@test vcat(m,m) == collect(vcat(x,x))
209128
@test_throws DimensionMismatch vcat(x,y)
210129
end
211130

212131
@testset "scale" begin
213132
x = rand(10,10)
214-
X = distribute(x, Blocks(3,3))
133+
X = Distribute(Blocks(3,3), x)
215134
y = rand(10)
216135

217136
@test Diagonal(y)*x == collect(Diagonal(y)*X)
218137
end
219138

220139
@testset "Getindex" begin
221140
function test_getindex(x)
222-
X = distribute(x, Blocks(3,3))
141+
X = Distribute(Blocks(3,3), x)
223142
@test collect(X[3:8, 2:7]) == x[3:8, 2:7]
224143
ragged_idx = [1,2,9,7,6,2,4,5]
225144
@test collect(X[ragged_idx, 2:7]) == x[ragged_idx, 2:7]
@@ -248,7 +167,7 @@ end
248167

249168

250169
@testset "cleanup" begin
251-
X = distribute(rand(10,10), Blocks(10,10))
170+
X = Distribute(Blocks(10,10), rand(10,10))
252171
@test collect(sin.(X)) == collect(sin.(X))
253172
end
254173

@@ -269,7 +188,7 @@ end
269188
x=rand(10,10)
270189
y=copy(x)
271190
y[3:8, 2:7] .= 1.0
272-
X = distribute(x, Blocks(3,3))
191+
X = Distribute(Blocks(3,3), x)
273192
@test collect(setindex(X,1.0, 3:8, 2:7)) == y
274193
@test collect(X) == x
275194
end
@@ -292,7 +211,7 @@ end
292211
@test collect(sort(y)) == x
293212

294213
x = ones(10)
295-
y = distribute(x, Blocks(3))
214+
y = Distribute(Blocks(3), x)
296215
@test_broken map(x->length(collect(x)), sort(y).chunks) == [3,3,3,1]
297216
end
298217

test/array/linalg/cholesky.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
@testset "$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
2+
D = rand(Blocks(4, 4), T, 32, 32)
3+
if !(T <: Complex)
4+
@test !issymmetric(D)
5+
end
6+
@test !ishermitian(D)
7+
8+
A = rand(T, 128, 128)
9+
A = A * A'
10+
A[diagind(A)] .+= size(A, 1)
11+
DA = view(A, Blocks(32, 32))
12+
if !(T <: Complex)
13+
@test issymmetric(DA)
14+
end
15+
@test ishermitian(DA)
16+
17+
# Out-of-place
18+
chol_A = cholesky(A)
19+
chol_DA = cholesky(DA)
20+
@test chol_DA isa Cholesky
21+
@test chol_A.L chol_DA.L
22+
@test chol_A.U chol_DA.U
23+
24+
# In-place
25+
A_copy = copy(A)
26+
chol_A = cholesky!(A_copy)
27+
chol_DA = cholesky!(DA)
28+
@test chol_DA isa Cholesky
29+
@test chol_A.L chol_DA.L
30+
@test chol_A.U chol_DA.U
31+
# Check that changes propagated to A
32+
@test UpperTriangular(collect(DA)) UpperTriangular(collect(A))
33+
34+
# Non-PosDef matrix
35+
A = rand(T, 128, 128)
36+
A = A * A'
37+
A[diagind(A)] .+= size(A, 1)
38+
A[1, 1] = -100
39+
DA = view(A, Blocks(32, 32))
40+
if !(T <: Complex)
41+
@test issymmetric(DA)
42+
end
43+
@test ishermitian(DA)
44+
@test_throws_unwrap PosDefException cholesky(DA).U
45+
end

test/linalg.jl renamed to test/array/linalg/matmul.jl

Lines changed: 25 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,29 @@
1-
using LinearAlgebra
1+
@testset "Matmul" begin
2+
X = rand(40, 40)
3+
tol = 1e-12
4+
5+
X1 = distribute(X, Blocks(10, 20))
6+
X2 = X1'*X1
7+
X3 = X1*X1'
8+
X4 = X1*X1
9+
10+
@test norm(collect(X2) - (X' * X)) < tol
11+
@test norm(collect(X3) - (X * X')) < tol
12+
@test norm(collect(X4) - (X * X)) < tol
13+
@test chunks(X2) |> size == (2, 2)
14+
@test chunks(X3) |> size == (4, 4)
15+
@test chunks(X4) |> size == (4, 2)
16+
@test map(x->size(x) == (20, 20), domainchunks(X2)) |> all
17+
@test map(x->size(x) == (10, 10), domainchunks(X3)) |> all
18+
@test map(x->size(x) == (10, 20), domainchunks(X4)) |> all
19+
20+
@testset "Powers" begin
21+
x = rand(Blocks(4,4), 16, 16)
22+
@test collect(x^1) == collect(x)
23+
@test collect(x^2) == collect(x*x)
24+
@test collect(x^3) == collect(x*x*x)
25+
end
226

3-
@testset "Linear Algebra" begin
427
@testset "GEMM: $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
528
A = rand(T, 128, 128)
629
B = rand(T, 128, 128)
@@ -84,50 +107,4 @@ using LinearAlgebra
84107
mul!(DC, DA', DA)
85108
@test collect(DC) C
86109
end
87-
88-
@testset "Cholesky: $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
89-
D = rand(Blocks(4, 4), T, 32, 32)
90-
if !(T <: Complex)
91-
@test !issymmetric(D)
92-
end
93-
@test !ishermitian(D)
94-
95-
A = rand(T, 128, 128)
96-
A = A * A'
97-
A[diagind(A)] .+= size(A, 1)
98-
DA = view(A, Blocks(32, 32))
99-
if !(T <: Complex)
100-
@test issymmetric(DA)
101-
end
102-
@test ishermitian(DA)
103-
104-
# Out-of-place
105-
chol_A = cholesky(A)
106-
chol_DA = cholesky(DA)
107-
@test chol_DA isa Cholesky
108-
@test chol_A.L chol_DA.L
109-
@test chol_A.U chol_DA.U
110-
111-
# In-place
112-
A_copy = copy(A)
113-
chol_A = cholesky!(A_copy)
114-
chol_DA = cholesky!(DA)
115-
@test chol_DA isa Cholesky
116-
@test chol_A.L chol_DA.L
117-
@test chol_A.U chol_DA.U
118-
# Check that changes propagated to A
119-
@test UpperTriangular(collect(DA)) UpperTriangular(collect(A))
120-
121-
# Non-PosDef matrix
122-
A = rand(T, 128, 128)
123-
A = A * A'
124-
A[diagind(A)] .+= size(A, 1)
125-
A[1, 1] = -100
126-
DA = view(A, Blocks(32, 32))
127-
if !(T <: Complex)
128-
@test issymmetric(DA)
129-
end
130-
@test ishermitian(DA)
131-
@test_throws_unwrap PosDefException cholesky(DA).U
132-
end
133110
end

test/array/mapreduce.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
function test_mapreduce(f, init_func; no_init=true, zero_init=zero,
2+
types=(Int32, Int64, Float32, Float64),
3+
cmp=isapprox)
4+
@testset "$T" for T in types
5+
X = init_func(Blocks(10, 10), T, 100, 100)
6+
inits = ()
7+
if no_init
8+
inits = (inits..., nothing)
9+
end
10+
if zero_init !== nothing
11+
inits = (inits..., zero_init(T))
12+
end
13+
@testset "dims=$dims" for dims in (Colon(), 1, 2, (1,), (2,))
14+
@testset "init=$init" for init in inits
15+
if init === nothing
16+
if dims == Colon()
17+
@test cmp(f(X; dims), f(collect(X); dims))
18+
else
19+
@test cmp(collect(f(X; dims)), f(collect(X); dims))
20+
end
21+
else
22+
if dims == Colon()
23+
@test cmp(f(X; dims, init), f(collect(X); dims, init))
24+
else
25+
@test cmp(collect(f(X; dims, init)), f(collect(X); dims, init))
26+
end
27+
end
28+
end
29+
end
30+
end
31+
end
32+
33+
# Base
34+
@testset "reduce" test_mapreduce((X; dims, init=Base._InitialValue())->reduce(+, X; dims, init), ones)
35+
@testset "mapreduce" test_mapreduce((X; dims, init=Base._InitialValue())->mapreduce(x->x+1, +, X; dims, init), ones)
36+
@testset "sum" test_mapreduce(sum, ones)
37+
@testset "prod" test_mapreduce(prod, rand)
38+
@testset "minimum" test_mapreduce(minimum, rand)
39+
@testset "maximum" test_mapreduce(maximum, rand)
40+
@testset "extrema" test_mapreduce(extrema, rand; cmp=Base.:(==), zero_init=T->(zero(T), zero(T)))
41+
42+
# Statistics
43+
@testset "mean" test_mapreduce(mean, rand; zero_init=nothing, types=(Float32, Float64))
44+
@testset "var" test_mapreduce(var, rand; zero_init=nothing, types=(Float32, Float64))
45+
@testset "std" test_mapreduce(std, rand; zero_init=nothing, types=(Float32, Float64))

test/imports.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using LinearAlgebra, SparseArrays, Random, SharedArrays
2+
import Dagger: DArray, chunks, domainchunks, treereduce_nd
3+
import Distributed: myid, procs
4+
import Statistics: mean, var, std
5+
import OnlineStats

test/runtests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ tests = [
1111
("Task Queues", "task-queues.jl"),
1212
("Datadeps", "datadeps.jl"),
1313
("Domain Utilities", "domain.jl"),
14-
("Array", "array.jl"),
15-
("Linear Algebra", "linalg.jl"),
14+
("Array - Core", "array/core.jl"),
15+
("Array - MapReduce", "array/mapreduce.jl"),
16+
("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"),
17+
("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
1618
("Caching", "cache.jl"),
1719
("Disk Caching", "diskcaching.jl"),
1820
("File IO", "file-io.jl"),
@@ -70,6 +72,7 @@ end
7072
using Distributed
7173
addprocs(3)
7274

75+
include("imports.jl")
7376
include("util.jl")
7477
include("fakeproc.jl")
7578

0 commit comments

Comments
 (0)