|
14 | 14 | // RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s |
15 | 15 |
|
16 | 16 | #packed_maps = [ |
17 | | - affine_map<(d0, d1, d2) -> (d0, d2)>, |
18 | | - affine_map<(d0, d1, d2) -> (d1, d2)>, |
19 | | - affine_map<(d0, d1, d2) -> (d0, d1)> |
| 17 | + affine_map<(m, n, k) -> (m, k)>, |
| 18 | + affine_map<(m, n, k) -> (n, k)>, |
| 19 | + affine_map<(m, n, k) -> (m, n)> |
20 | 20 | ] |
21 | 21 |
|
22 | 22 | // |
|
38 | 38 | // * RHS: vector<[N]x8xi8> |
39 | 39 | // * ACC, OUT: vector<Mx[N]xi32> |
40 | 40 | // Note that the RHS is transposed. |
| 41 | +// This data layout makes it efficient to load data into SVE |
| 42 | +// registers in the layout expected by FEAT_I8MM instructions. |
| 43 | +// Such a `vector.contract` is representative of the code we aim to generate |
| 44 | +// by scalable vectorisation of `linalg.mmt4d`. |
41 | 45 | // See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp |
42 | 46 | // for more information and rationale about these shapes. |
43 | 47 | // |
@@ -150,7 +154,7 @@ func.func @test_smmla() { |
150 | 154 | %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8> |
151 | 155 | %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8> |
152 | 156 |
|
153 | | - // Matrix multiplication |
| 157 | + // Matrix multiplication and accumulate with transposed RHS. |
154 | 158 | %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> |
155 | 159 | %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> |
156 | 160 | %2 = vector.contract {indexing_maps = #packed_maps, |
@@ -216,7 +220,7 @@ func.func @test_ummla() { |
216 | 220 | %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8> |
217 | 221 | %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8> |
218 | 222 |
|
219 | | - // Matrix multiplication |
| 223 | + // Matrix multiplication and accumulate with transposed RHS. |
220 | 224 | %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32> |
221 | 225 | %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> |
222 | 226 | %2 = vector.contract {indexing_maps = #packed_maps, |
@@ -283,7 +287,7 @@ func.func @test_usmmla() { |
283 | 287 | %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8> |
284 | 288 | %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8> |
285 | 289 |
|
286 | | - // Matrix multiplication |
| 290 | + // Matrix multiplication and accumulate with transposed RHS. |
287 | 291 | %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32> |
288 | 292 | %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> |
289 | 293 | %2 = vector.contract {indexing_maps = #packed_maps, |
@@ -351,7 +355,7 @@ func.func @test_summla() { |
351 | 355 | %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8> |
352 | 356 | %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8> |
353 | 357 |
|
354 | | - // Matrix multiplication |
| 358 | + // Matrix multiplication and accumulate with transposed RHS. |
355 | 359 | %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> |
356 | 360 | %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32> |
357 | 361 | %2 = vector.contract {indexing_maps = #packed_maps, |
|
0 commit comments