|
1 |
| -@testset "Matmul" begin |
| 1 | +@testset "With/Without Transpose" begin |
2 | 2 | X = rand(40, 40)
|
3 | 3 | tol = 1e-12
|
4 |
| - |
5 | 4 | X1 = distribute(X, Blocks(10, 20))
|
6 | 5 | X2 = X1'*X1
|
7 | 6 | X3 = X1*X1'
|
8 | 7 | X4 = X1*X1
|
9 |
| - |
10 | 8 | @test norm(collect(X2) - (X' * X)) < tol
|
11 | 9 | @test norm(collect(X3) - (X * X')) < tol
|
12 | 10 | @test norm(collect(X4) - (X * X)) < tol
|
13 |
| - @test chunks(X2) |> size == (2, 2) |
14 |
| - @test chunks(X3) |> size == (4, 4) |
15 |
| - @test chunks(X4) |> size == (4, 2) |
16 |
| - @test map(x->size(x) == (20, 20), domainchunks(X2)) |> all |
17 |
| - @test map(x->size(x) == (10, 10), domainchunks(X3)) |> all |
18 |
| - @test map(x->size(x) == (10, 20), domainchunks(X4)) |> all |
19 |
| - |
20 |
| - @testset "Powers" begin |
21 |
| - x = rand(Blocks(4,4), 16, 16) |
22 |
| - @test collect(x^1) == collect(x) |
23 |
| - @test collect(x^2) == collect(x*x) |
24 |
| - @test collect(x^3) == collect(x*x*x) |
25 |
| - end |
| 11 | +end |
26 | 12 |
|
27 |
| - @testset "GEMM: $T" for T in (Float32, Float64, ComplexF32, ComplexF64) |
28 |
| - A = rand(T, 128, 128) |
29 |
| - B = rand(T, 128, 128) |
| 13 | +@testset "Powers" begin |
| 14 | + x = rand(Blocks(4,4), 16, 16) |
| 15 | + @test collect(x^1) == collect(x) |
| 16 | + @test collect(x^2) == collect(x*x) |
| 17 | + @test collect(x^3) == collect(x*x*x) |
| 18 | +end |
30 | 19 |
|
31 |
| - DA = view(A, Blocks(32, 32)) |
32 |
| - DB = view(B, Blocks(32, 32)) |
| 20 | +function test_gemm!(T, szA, szB, partA, partB) |
| 21 | + @assert szA[1] == szB[2] |
| 22 | + szC = (szA[1], szA[1]) |
| 23 | + @assert partA.blocksize[1] == partB.blocksize[2] |
| 24 | + partC = Blocks(partA.blocksize[1], partB.blocksize[2]) |
33 | 25 |
|
34 |
| - ## Out-of-place gemm |
35 |
| - # No transA, No transB |
36 |
| - DC = DA * DB |
37 |
| - C = A * B |
38 |
| - @test collect(DC) ≈ C |
| 26 | + A = rand(T, szA...) |
| 27 | + B = rand(T, szB...) |
| 28 | + |
| 29 | + DA = distribute(A, partA) |
| 30 | + DB = distribute(B, partB) |
39 | 31 |
|
| 32 | + ## Out-of-place gemm |
| 33 | + # No transA, No transB |
| 34 | + DC = DA * DB |
| 35 | + C = A * B |
| 36 | + @test collect(DC) ≈ C |
| 37 | + |
| 38 | + if szA == szB |
40 | 39 | # No transA, transB
|
41 | 40 | DC = DA * DB'
|
42 | 41 | C = A * B'
|
|
46 | 45 | DC = DA' * DB
|
47 | 46 | C = A' * B
|
48 | 47 | @test collect(DC) ≈ C
|
| 48 | + end |
49 | 49 |
|
50 |
| - # transA, transB |
51 |
| - DC = DA' * DB' |
52 |
| - C = A' * B' |
53 |
| - @test collect(DC) ≈ C |
| 50 | + # transA, transB |
| 51 | + DC = DA' * DB' |
| 52 | + C = A' * B' |
| 53 | + @test collect(DC) ≈ C |
54 | 54 |
|
55 |
| - ## In-place gemm |
56 |
| - # No transA, No transB |
57 |
| - C = zeros(T, 128, 128) |
58 |
| - DC = view(C, Blocks(32, 32)) |
59 |
| - mul!(C, A, B) |
60 |
| - mul!(DC, DA, DB) |
61 |
| - @test collect(DC) ≈ C |
| 55 | + ## In-place gemm |
| 56 | + # No transA, No transB |
| 57 | + C = zeros(T, szC...) |
| 58 | + DC = distribute(C, partC) |
| 59 | + mul!(C, A, B) |
| 60 | + mul!(DC, DA, DB) |
| 61 | + @test collect(DC) ≈ C |
62 | 62 |
|
| 63 | + if szA == szB |
63 | 64 | # No transA, transB
|
64 |
| - C = zeros(T, 128, 128) |
65 |
| - DC = view(C, Blocks(32, 32)) |
| 65 | + C = zeros(T, szC...) |
| 66 | + DC = distribute(C, partC) |
66 | 67 | mul!(C, A, B')
|
67 | 68 | mul!(DC, DA, DB')
|
68 | 69 | @test collect(DC) ≈ C
|
69 | 70 |
|
70 | 71 | # transA, No transB
|
71 |
| - C = zeros(T, 128, 128) |
72 |
| - DC = view(C, Blocks(32, 32)) |
| 72 | + C = zeros(T, szC...) |
| 73 | + DC = distribute(C, partC) |
73 | 74 | mul!(C, A', B)
|
74 | 75 | mul!(DC, DA', DB)
|
75 | 76 | @test collect(DC) ≈ C
|
| 77 | + end |
76 | 78 |
|
77 |
| - # transA, transB |
78 |
| - C = zeros(T, 128, 128) |
79 |
| - DC = view(C, Blocks(32, 32)) |
80 |
| - mul!(C, A', B') |
81 |
| - mul!(DC, DA', DB') |
82 |
| - collect(DC) ≈ C |
| 79 | + # transA, transB |
| 80 | + C = zeros(T, szA[2], szA[2]) |
| 81 | + DC = distribute(C, partC) |
| 82 | + mul!(C, A', B') |
| 83 | + mul!(DC, DA', DB') |
| 84 | + collect(DC) ≈ C |
83 | 85 |
|
| 86 | + if szA == szB |
84 | 87 | ## Out-of-place syrk
|
85 | 88 | # No trans, trans
|
86 | 89 | DC = DA * DA'
|
|
94 | 97 |
|
95 | 98 | ## In-place syrk
|
96 | 99 | # No trans, trans
|
97 |
| - C = zeros(T, 128, 128) |
98 |
| - DC = distribute(C, Blocks(32, 32)) |
| 100 | + C = zeros(T, szC...) |
| 101 | + DC = distribute(C, partC) |
99 | 102 | mul!(C, A, A')
|
100 | 103 | mul!(DC, DA, DA')
|
101 | 104 | @test collect(DC) ≈ C
|
102 | 105 |
|
103 | 106 | # trans, No trans
|
104 |
| - C = zeros(T, 128, 128) |
105 |
| - DC = distribute(C, Blocks(32, 32)) |
| 107 | + C = zeros(T, szC...) |
| 108 | + DC = distribute(C, partC) |
106 | 109 | mul!(C, A', A)
|
107 | 110 | mul!(DC, DA', DA)
|
108 | 111 | @test collect(DC) ≈ C
|
109 | 112 | end
|
110 | 113 | end
|
| 114 | + |
| 115 | +_sizes_to_test = [ |
| 116 | + (4, 4), |
| 117 | + (7, 7), |
| 118 | + (12, 12), |
| 119 | + (16, 16), |
| 120 | +] |
| 121 | +size_sets_to_test = map(_sizes_to_test) do sz |
| 122 | + rows, cols = sz |
| 123 | + return [ |
| 124 | + (rows, cols) => (cols, rows), |
| 125 | + (rows ÷ 2, cols) => (cols, rows ÷ 2), |
| 126 | + (rows, cols ÷ 2) => (cols ÷ 2, rows), |
| 127 | + ] |
| 128 | +end |
| 129 | +sizes_to_test = vcat(size_sets_to_test...) |
| 130 | +part_sets_to_test = map(_sizes_to_test) do sz |
| 131 | + rows, cols = sz |
| 132 | + return [ |
| 133 | + Blocks(rows, cols) => Blocks(cols, rows), |
| 134 | + Blocks(rows ÷ 2, cols) => Blocks(cols, rows ÷ 2), |
| 135 | + Blocks(rows, cols ÷ 2) => Blocks(cols ÷ 2, rows), |
| 136 | + ] |
| 137 | +end |
| 138 | +parts_to_test = vcat(part_sets_to_test...) |
| 139 | +@testset "Size=$szA*$szB" for (szA, szB) in sizes_to_test |
| 140 | + @testset "Partitioning=$partA*$partB" for (partA,partB) in parts_to_test |
| 141 | + @testset "T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64) |
| 142 | + test_gemm!(T, szA, szB, partA, partB) |
| 143 | + end |
| 144 | + end |
| 145 | +end |
0 commit comments