1919 affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
2020]
2121
22- func.func private @setArmVLBits (%bits : i32 )
23- func.func private @printMemrefI32 (%ptr : memref <*xi32 >)
24-
22+ //
23+ // Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern`
24+ //
25+ // The operation that the `vector.contract` in this test performs is matrix
26+ // multiplication with accumulate
27+ // OUT = ACC + LHS * RHS
28+ // of two 8-bit integer matrices LHS and RHS, and a 32-bit integer matrix ACC
29+ // into a 32-bit integer matrix OUT. The LHS and RHS can be sign- or zero- extended,
30+ // this test covers all the possible variants.
31+ //
32+ // Tested are the calculations as well as that the relevant `ArmSVE` dialect
33+ // operations (`arm_sve.smmla`, `arm_sve.ummla`, etc.) are emitted.
34+ //
35+ // The pattern above handles (therefore this test prepares) input/output vectors with
36+ // specific shapes:
37+ // * LHS: vector<Mx8xi8>
38+ // * RHS: vector<[N]x8xi8>
39+ // * ACC, OUT: vector<Mx[N]xi32>
40+ // Note that the RHS is transposed.
41+ // See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
42+ // for more information and rationale about these shapes.
43+ //
44+ // In this specific test we use M == 4 and N == 4
45+ //
46+
47+ // Allocate and initialise a memref containing test data for use as the ACC
48+ // operand. The memref has one dynamic dimension whose extent depends on the
49+ // runtime value of VSCALE.
50+ //
51+ // The input parameter `%in` is a vector that is replicated VSCALE times
52+ // across the columns of the memref.
2553func.func private @prepareAccTestData (%in: vector <4 x4 xi32 >) -> memref <4 x?xi32 > {
2654 %c0 = arith.constant 0 : index
2755 %c1 = arith.constant 1 : index
@@ -39,6 +67,9 @@ func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> memref<4x?xi32> {
3967 return %mem : memref <4 x?xi32 >
4068}
4169
70+ // Allocate and initialise a memref containing test data for use as the LHS
71+ // operand. This function just writes the parameter `%in` into the memref.
72+ // The size of the LHS does not depend on VSCALE.
4273func.func private @prepareLHSTestData (%in: vector <4 x8 xi8 >) -> memref <4 x8 xi8 > {
4374 %c0 = arith.constant 0 : index
4475 %c0_i8 = arith.constant 0 : i8
@@ -49,6 +80,15 @@ func.func private @prepareLHSTestData(%in: vector<4x8xi8>) -> memref<4x8xi8> {
4980 return %mem : memref <4 x8 xi8 >
5081}
5182
83+ // Allocate and initialise a memref containing test data for use as the RHS
84+ // operand. The memref has one dynamic dimension whose extent depends on the
85+ // runtime value of VSCALE.
86+ //
87+ // The input parameter `%in` is a vector that is replicated VSCALE times
88+ // across the rows of the memref.
89+ //
90+ // For convenience, flatten the memref, since the RHS vector is read first as a
91+ // single-dimensional scalable vector and then cast into [N]x8 shape.
5292func.func private @prepareRHSTestData (%in: vector <4 x8 xi8 >) -> memref <?xi8 > {
5393 %c0 = arith.constant 0 : index
5494 %c1 = arith.constant 1 : index
@@ -67,6 +107,9 @@ func.func private @prepareRHSTestData(%in: vector<4x8xi8>) -> memref<?xi8> {
67107 return %mem_out : memref <?xi8 >
68108}
69109
110+ // Test the operation where both LHS and RHS are interpreted as signed, hence
111+ // we ultimately emit and execute the `smmla` instruction.
112+
70113// CHECK-IR-LABEL: llvm.func @test_smmla
71114// CHECK-IR-COUNT-4: arm_sve.intr.smmla
72115func.func @test_smmla () {
@@ -84,7 +127,7 @@ func.func @test_smmla() {
84127 %acc_mem = func.call @prepareAccTestData (%acc_cst ) : (vector <4 x4 xi32 >) -> memref <4 x?xi32 >
85128 %acc = vector.transfer_read %acc_mem [%c0 , %c0 ], %c0_i32 {in_bounds = [true , true ]} : memref <4 x?xi32 >, vector <4 x[4 ]xi32 >
86129
87- // Workaround for a crash, see https://github.com/llvm/llvm-project/issues/143670
130+ // FIXME: Workaround for a crash, see https://github.com/llvm/llvm-project/issues/143670
88131 %acc_cast = memref.cast %acc_mem : memref <4 x?xi32 > to memref <*xi32 >
89132 call @printMemrefI32 (%acc_cast ) : (memref <*xi32 >) -> ()
90133
@@ -126,9 +169,17 @@ func.func @test_smmla() {
126169 vector.print %u2 : vector <[4 ]xi32 >
127170 vector.print %u3 : vector <[4 ]xi32 >
128171
172+ // Deallocate the buffers.
173+ memref.dealloc %acc_mem : memref <4 x?xi32 >
174+ memref.dealloc %lhs_mem : memref <4 x8 xi8 >
175+ memref.dealloc %rhs_mem : memref <?xi8 >
176+
129177 return
130178}
131179
180+ // Test the operation where both LHS and RHS are interpreted as unsigned, hence
181+ // we ultimately emit and execute the `ummla` instruction.
182+
132183// CHECK-IR-LABEL: llvm.func @test_ummla
133184// CHECK-IR-COUNT-4: arm_sve.intr.ummla
134185func.func @test_ummla () {
@@ -184,9 +235,18 @@ func.func @test_ummla() {
184235 vector.print %u2 : vector <[4 ]xi32 >
185236 vector.print %u3 : vector <[4 ]xi32 >
186237
238+ // Deallocate the buffers.
239+ memref.dealloc %acc_mem : memref <4 x?xi32 >
240+ memref.dealloc %lhs_mem : memref <4 x8 xi8 >
241+ memref.dealloc %rhs_mem : memref <?xi8 >
242+
187243 return
188244}
189245
246+ // Test the operation where LHS is interpreted as unsigned and RHS is
247+ // interpreted as signed, hence we ultimately emit and execute the `usmmla`
248+ // instruction.
249+
190250// CHECK-IR-LABEL: llvm.func @test_usmmla
191251// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
192252func.func @test_usmmla () {
@@ -242,9 +302,19 @@ func.func @test_usmmla() {
242302 vector.print %u2 : vector <[4 ]xi32 >
243303 vector.print %u3 : vector <[4 ]xi32 >
244304
305+ // Deallocate the buffers.
306+ memref.dealloc %acc_mem : memref <4 x?xi32 >
307+ memref.dealloc %lhs_mem : memref <4 x8 xi8 >
308+ memref.dealloc %rhs_mem : memref <?xi8 >
309+
245310 return
246311}
247312
313+ // Test the operation where LHS is interpreted as signed and RHS is interpreted
314+ // as unsigned. In this test we ultimately emit and execute the `usmmla`
315+ // instruction with reversed operands, see `LowerContractionToSVEI8MMPattern.cpp`
316+ // for more details.
317+
248318// CHECK-IR-LABEL: llvm.func @test_summla
249319// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
250320func.func @test_summla () {
@@ -300,9 +370,20 @@ func.func @test_summla() {
300370 vector.print %u2 : vector <[4 ]xi32 >
301371 vector.print %u3 : vector <[4 ]xi32 >
302372
373+ // Deallocate the buffers.
374+ memref.dealloc %acc_mem : memref <4 x?xi32 >
375+ memref.dealloc %lhs_mem : memref <4 x8 xi8 >
376+ memref.dealloc %rhs_mem : memref <?xi8 >
377+
303378 return
304379}
305380
381+ // Perform each test with SVE vector lengths 128 bits and 256 bits (i.e. VSCALEs
382+ // 1 and 2, respectively). The vector length is set via the `setArmVLBits`
383+ // function. The effect of setting a different vector length is that the tests
384+ // allocate and operate on different sized buffers (see `prepare<X>TestData`
385+ // functions).
386+
306387func.func @main () {
307388 %c128 = arith.constant 128 : i32
308389 %c256 = arith.constant 256 : i32
@@ -373,3 +454,6 @@ func.func @main() {
373454
374455 return
375456}
457+
458+ func.func private @setArmVLBits (%bits : i32 )
459+ func.func private @printMemrefI32 (%ptr : memref <*xi32 >)
0 commit comments