Skip to content

Commit 4de82c1

Browse files
[fixup] Commenting
1 parent a4564e1 commit 4de82c1

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1515

1616
#packed_maps = [
17-
affine_map<(d0, d1, d2) -> (d0, d2)>,
18-
affine_map<(d0, d1, d2) -> (d1, d2)>,
19-
affine_map<(d0, d1, d2) -> (d0, d1)>
17+
affine_map<(m, n, k) -> (m, k)>,
18+
affine_map<(m, n, k) -> (n, k)>,
19+
affine_map<(m, n, k) -> (m, n)>
2020
]
2121

2222
//
@@ -38,6 +38,10 @@
3838
// * RHS: vector<[N]x8xi8>
3939
// * ACC, OUT: vector<Mx[N]xi32>
4040
// Note that the RHS is transposed.
41+
// This data layout makes it efficient to load data into SVE
42+
// registers in the layout expected by FEAT_I8MM instructions.
43+
// Such a `vector.contract` is representative of the code we aim to generate
44+
// by scalable vectorisation of `linalg.mmt4d`.
4145
// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
4246
// for more information and rationale about these shapes.
4347
//
@@ -150,7 +154,7 @@ func.func @test_smmla() {
150154
%rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
151155
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
152156

153-
// Matrix multiplication
157+
// Matrix multiply-accumulate with transposed RHS.
154158
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
155159
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
156160
%2 = vector.contract {indexing_maps = #packed_maps,
@@ -216,7 +220,7 @@ func.func @test_ummla() {
216220
%rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
217221
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
218222

219-
// Matrix multiplication
223+
// Matrix multiply-accumulate with transposed RHS.
220224
%0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
221225
%1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
222226
%2 = vector.contract {indexing_maps = #packed_maps,
@@ -283,7 +287,7 @@ func.func @test_usmmla() {
283287
%rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
284288
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
285289

286-
// Matrix multiplication
290+
// Matrix multiply-accumulate with transposed RHS.
287291
%0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
288292
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
289293
%2 = vector.contract {indexing_maps = #packed_maps,
@@ -351,7 +355,7 @@ func.func @test_summla() {
351355
%rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
352356
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
353357

354-
// Matrix multiplication
358+
// Matrix multiply-accumulate with transposed RHS.
355359
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
356360
%1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
357361
%2 = vector.contract {indexing_maps = #packed_maps,

0 commit comments

Comments
 (0)