Skip to content

Commit a4564e1

Browse files
[fixup] More commenting
1 parent 245b6bc commit a4564e1

File tree

1 file changed

+88
-4
lines changed

1 file changed

+88
-4
lines changed

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,37 @@
1919
affine_map<(d0, d1, d2) -> (d0, d1)>
2020
]
2121

22-
func.func private @setArmVLBits(%bits : i32)
23-
func.func private @printMemrefI32(%ptr : memref<*xi32>)
24-
22+
//
23+
// Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern`
24+
//
25+
// The operation that the `vector.contract` in this test performs is matrix
26+
// multiplication with accumulate
27+
// OUT = ACC + LHS * RHS
28+
// of two 8-bit integer matrices LHS and RHS, and a 32-bit integer matrix ACC
29+
// into a 32-bit integer matrix OUT. The LHS and RHS can be sign- or zero-extended;
30+
// this test covers all the possible variants.
31+
//
32+
// This test checks the computed results and that the relevant `ArmSVE` dialect
33+
// operations (`arm_sve.smmla`, `arm_sve.ummla`, etc.) are emitted.
34+
//
35+
// The pattern above handles (and therefore this test prepares) input/output vectors with
36+
// specific shapes:
37+
// * LHS: vector<Mx8xi8>
38+
// * RHS: vector<[N]x8xi8>
39+
// * ACC, OUT: vector<Mx[N]xi32>
40+
// Note that the RHS is transposed.
41+
// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
42+
// for more information and rationale about these shapes.
43+
//
44+
// In this specific test we use M == 4 and N == 4
45+
//
46+
47+
// Allocate and initialise a memref containing test data for use as the ACC
48+
// operand. The memref has one dynamic dimension whose extent depends on the
49+
// runtime value of VSCALE.
50+
//
51+
// The input parameter `%in` is a vector that is replicated VSCALE times
52+
// across the columns of the memref.
2553
func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> memref<4x?xi32> {
2654
%c0 = arith.constant 0 : index
2755
%c1 = arith.constant 1 : index
@@ -39,6 +67,9 @@ func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> memref<4x?xi32> {
3967
return %mem : memref<4x?xi32>
4068
}
4169

70+
// Allocate and initialise a memref containing test data for use as the LHS
71+
// operand. This function just writes the parameter `%in` into the memref.
72+
// The size of the LHS does not depend on VSCALE.
4273
func.func private @prepareLHSTestData(%in: vector<4x8xi8>) -> memref<4x8xi8> {
4374
%c0 = arith.constant 0 : index
4475
%c0_i8 = arith.constant 0 : i8
@@ -49,6 +80,15 @@ func.func private @prepareLHSTestData(%in: vector<4x8xi8>) -> memref<4x8xi8> {
4980
return %mem : memref<4x8xi8>
5081
}
5182

83+
// Allocate and initialise a memref containing test data for use as the RHS
84+
// operand. The memref has one dynamic dimension whose extent depends on the
85+
// runtime value of VSCALE.
86+
//
87+
// The input parameter `%in` is a vector that is replicated VSCALE times
88+
// across the rows of the memref.
89+
//
90+
// For convenience, flatten the memref, since the RHS vector is read first as a
91+
// single-dimensional scalable vector and then cast into [N]x8 shape.
5292
func.func private @prepareRHSTestData(%in: vector<4x8xi8>) -> memref<?xi8> {
5393
%c0 = arith.constant 0 : index
5494
%c1 = arith.constant 1 : index
@@ -67,6 +107,9 @@ func.func private @prepareRHSTestData(%in: vector<4x8xi8>) -> memref<?xi8> {
67107
return %mem_out : memref<?xi8>
68108
}
69109

110+
// Test the operation where both LHS and RHS are interpreted as signed, hence
111+
// we ultimately emit and execute the `smmla` instruction.
112+
70113
// CHECK-IR-LABEL: llvm.func @test_smmla
71114
// CHECK-IR-COUNT-4: arm_sve.intr.smmla
72115
func.func @test_smmla() {
@@ -84,7 +127,7 @@ func.func @test_smmla() {
84127
%acc_mem = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xi32>) -> memref<4x?xi32>
85128
%acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : memref<4x?xi32>, vector<4x[4]xi32>
86129

87-
// Workaround for a crash, see https://github.com/llvm/llvm-project/issues/143670
130+
// FIXME: Workaround for a crash, see https://github.com/llvm/llvm-project/issues/143670
88131
%acc_cast = memref.cast %acc_mem : memref<4x?xi32> to memref<*xi32>
89132
call @printMemrefI32(%acc_cast) : (memref<*xi32>) -> ()
90133

@@ -126,9 +169,17 @@ func.func @test_smmla() {
126169
vector.print %u2 : vector<[4]xi32>
127170
vector.print %u3 : vector<[4]xi32>
128171

172+
// Deallocate the buffers.
173+
memref.dealloc %acc_mem : memref<4x?xi32>
174+
memref.dealloc %lhs_mem : memref<4x8xi8>
175+
memref.dealloc %rhs_mem : memref<?xi8>
176+
129177
return
130178
}
131179

180+
// Test the operation where both LHS and RHS are interpreted as unsigned, hence
181+
// we ultimately emit and execute the `ummla` instruction.
182+
132183
// CHECK-IR-LABEL: llvm.func @test_ummla
133184
// CHECK-IR-COUNT-4: arm_sve.intr.ummla
134185
func.func @test_ummla() {
@@ -184,9 +235,18 @@ func.func @test_ummla() {
184235
vector.print %u2 : vector<[4]xi32>
185236
vector.print %u3 : vector<[4]xi32>
186237

238+
// Deallocate the buffers.
239+
memref.dealloc %acc_mem : memref<4x?xi32>
240+
memref.dealloc %lhs_mem : memref<4x8xi8>
241+
memref.dealloc %rhs_mem : memref<?xi8>
242+
187243
return
188244
}
189245

246+
// Test the operation where LHS is interpreted as unsigned and RHS is
247+
// interpreted as signed, hence we ultimately emit and execute the `usmmla`
248+
// instruction.
249+
190250
// CHECK-IR-LABEL: llvm.func @test_usmmla
191251
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
192252
func.func @test_usmmla() {
@@ -242,9 +302,19 @@ func.func @test_usmmla() {
242302
vector.print %u2 : vector<[4]xi32>
243303
vector.print %u3 : vector<[4]xi32>
244304

305+
// Deallocate the buffers.
306+
memref.dealloc %acc_mem : memref<4x?xi32>
307+
memref.dealloc %lhs_mem : memref<4x8xi8>
308+
memref.dealloc %rhs_mem : memref<?xi8>
309+
245310
return
246311
}
247312

313+
// Test the operation where LHS is interpreted as signed and RHS is interpreted
314+
// as unsigned. In this test we ultimately emit and execute the `usmmla`
315+
// instruction with reversed operands, see `LowerContractionToSVEI8MMPattern.cpp`
316+
// for more details.
317+
248318
// CHECK-IR-LABEL: llvm.func @test_summla
249319
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
250320
func.func @test_summla() {
@@ -300,9 +370,20 @@ func.func @test_summla() {
300370
vector.print %u2 : vector<[4]xi32>
301371
vector.print %u3 : vector<[4]xi32>
302372

373+
// Deallocate the buffers.
374+
memref.dealloc %acc_mem : memref<4x?xi32>
375+
memref.dealloc %lhs_mem : memref<4x8xi8>
376+
memref.dealloc %rhs_mem : memref<?xi8>
377+
303378
return
304379
}
305380

381+
// Perform each test with SVE vector lengths 128 bits and 256 bits (i.e. VSCALEs
382+
// 1 and 2, respectively). The vector length is set via the `setArmVLBits`
383+
// function. The effect of setting a different vector length is that the tests
384+
// allocate and operate on different sized buffers (see `prepare<X>TestData`
385+
// functions).
386+
306387
func.func @main() {
307388
%c128 = arith.constant 128 : i32
308389
%c256 = arith.constant 256 : i32
@@ -373,3 +454,6 @@ func.func @main() {
373454

374455
return
375456
}
457+
458+
func.func private @setArmVLBits(%bits : i32)
459+
func.func private @printMemrefI32(%ptr : memref<*xi32>)

0 commit comments

Comments
 (0)