Skip to content

Commit 290e5cf

Browse files
committed
tests: Expand matmul tests
1 parent 6925870 commit 290e5cf

File tree

1 file changed

+86
-51
lines changed

1 file changed

+86
-51
lines changed

test/array/linalg/matmul.jl

Lines changed: 86 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,41 @@
1-
@testset "Matmul" begin
1+
@testset "With/Without Transpose" begin
22
X = rand(40, 40)
33
tol = 1e-12
4-
54
X1 = distribute(X, Blocks(10, 20))
65
X2 = X1'*X1
76
X3 = X1*X1'
87
X4 = X1*X1
9-
108
@test norm(collect(X2) - (X' * X)) < tol
119
@test norm(collect(X3) - (X * X')) < tol
1210
@test norm(collect(X4) - (X * X)) < tol
13-
@test chunks(X2) |> size == (2, 2)
14-
@test chunks(X3) |> size == (4, 4)
15-
@test chunks(X4) |> size == (4, 2)
16-
@test map(x->size(x) == (20, 20), domainchunks(X2)) |> all
17-
@test map(x->size(x) == (10, 10), domainchunks(X3)) |> all
18-
@test map(x->size(x) == (10, 20), domainchunks(X4)) |> all
19-
20-
@testset "Powers" begin
21-
x = rand(Blocks(4,4), 16, 16)
22-
@test collect(x^1) == collect(x)
23-
@test collect(x^2) == collect(x*x)
24-
@test collect(x^3) == collect(x*x*x)
25-
end
11+
end
2612

27-
@testset "GEMM: $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
28-
A = rand(T, 128, 128)
29-
B = rand(T, 128, 128)
13+
@testset "Powers" begin
14+
x = rand(Blocks(4,4), 16, 16)
15+
@test collect(x^1) == collect(x)
16+
@test collect(x^2) == collect(x*x)
17+
@test collect(x^3) == collect(x*x*x)
18+
end
3019

31-
DA = view(A, Blocks(32, 32))
32-
DB = view(B, Blocks(32, 32))
20+
function test_gemm!(T, szA, szB, partA, partB)
21+
@assert szA[1] == szB[2]
22+
szC = (szA[1], szA[1])
23+
@assert partA.blocksize[1] == partB.blocksize[2]
24+
partC = Blocks(partA.blocksize[1], partB.blocksize[2])
3325

34-
## Out-of-place gemm
35-
# No transA, No transB
36-
DC = DA * DB
37-
C = A * B
38-
@test collect(DC) C
26+
A = rand(T, szA...)
27+
B = rand(T, szB...)
28+
29+
DA = distribute(A, partA)
30+
DB = distribute(B, partB)
3931

32+
## Out-of-place gemm
33+
# No transA, No transB
34+
DC = DA * DB
35+
C = A * B
36+
@test collect(DC) C
37+
38+
if szA == szB
4039
# No transA, transB
4140
DC = DA * DB'
4241
C = A * B'
@@ -46,41 +45,45 @@
4645
DC = DA' * DB
4746
C = A' * B
4847
@test collect(DC) C
48+
end
4949

50-
# transA, transB
51-
DC = DA' * DB'
52-
C = A' * B'
53-
@test collect(DC) C
50+
# transA, transB
51+
DC = DA' * DB'
52+
C = A' * B'
53+
@test collect(DC) C
5454

55-
## In-place gemm
56-
# No transA, No transB
57-
C = zeros(T, 128, 128)
58-
DC = view(C, Blocks(32, 32))
59-
mul!(C, A, B)
60-
mul!(DC, DA, DB)
61-
@test collect(DC) C
55+
## In-place gemm
56+
# No transA, No transB
57+
C = zeros(T, szC...)
58+
DC = distribute(C, partC)
59+
mul!(C, A, B)
60+
mul!(DC, DA, DB)
61+
@test collect(DC) C
6262

63+
if szA == szB
6364
# No transA, transB
64-
C = zeros(T, 128, 128)
65-
DC = view(C, Blocks(32, 32))
65+
C = zeros(T, szC...)
66+
DC = distribute(C, partC)
6667
mul!(C, A, B')
6768
mul!(DC, DA, DB')
6869
@test collect(DC) C
6970

7071
# transA, No transB
71-
C = zeros(T, 128, 128)
72-
DC = view(C, Blocks(32, 32))
72+
C = zeros(T, szC...)
73+
DC = distribute(C, partC)
7374
mul!(C, A', B)
7475
mul!(DC, DA', DB)
7576
@test collect(DC) C
77+
end
7678

77-
# transA, transB
78-
C = zeros(T, 128, 128)
79-
DC = view(C, Blocks(32, 32))
80-
mul!(C, A', B')
81-
mul!(DC, DA', DB')
82-
collect(DC) C
79+
# transA, transB
80+
C = zeros(T, szA[2], szA[2])
81+
DC = distribute(C, partC)
82+
mul!(C, A', B')
83+
mul!(DC, DA', DB')
84+
collect(DC) C
8385

86+
if szA == szB
8487
## Out-of-place syrk
8588
# No trans, trans
8689
DC = DA * DA'
@@ -94,17 +97,49 @@
9497

9598
## In-place syrk
9699
# No trans, trans
97-
C = zeros(T, 128, 128)
98-
DC = distribute(C, Blocks(32, 32))
100+
C = zeros(T, szC...)
101+
DC = distribute(C, partC)
99102
mul!(C, A, A')
100103
mul!(DC, DA, DA')
101104
@test collect(DC) C
102105

103106
# trans, No trans
104-
C = zeros(T, 128, 128)
105-
DC = distribute(C, Blocks(32, 32))
107+
C = zeros(T, szC...)
108+
DC = distribute(C, partC)
106109
mul!(C, A', A)
107110
mul!(DC, DA', DA)
108111
@test collect(DC) C
109112
end
110113
end
114+
115+
_sizes_to_test = [
116+
(4, 4),
117+
(7, 7),
118+
(12, 12),
119+
(16, 16),
120+
]
121+
size_sets_to_test = map(_sizes_to_test) do sz
122+
rows, cols = sz
123+
return [
124+
(rows, cols) => (cols, rows),
125+
(rows ÷ 2, cols) => (cols, rows ÷ 2),
126+
(rows, cols ÷ 2) => (cols ÷ 2, rows),
127+
]
128+
end
129+
sizes_to_test = vcat(size_sets_to_test...)
130+
part_sets_to_test = map(_sizes_to_test) do sz
131+
rows, cols = sz
132+
return [
133+
Blocks(rows, cols) => Blocks(cols, rows),
134+
Blocks(rows ÷ 2, cols) => Blocks(cols, rows ÷ 2),
135+
Blocks(rows, cols ÷ 2) => Blocks(cols ÷ 2, rows),
136+
]
137+
end
138+
parts_to_test = vcat(part_sets_to_test...)
139+
@testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test
140+
@testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test
141+
@testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
142+
test_gemm!(T, szA, szB, partA, partB)
143+
end
144+
end
145+
end

0 commit comments

Comments
 (0)