1010// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
1111// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
1212
13- // RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13+ // RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1414
1515#packed_maps = [
1616 affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>,
2020
2121func.func private @setArmVLBits (%bits : i32 )
2222
// Converts the fixed-size 4x4 i32 accumulator into the scalable
// vector<4x[4]xi32> form that @main passes on as the accumulator.
//
// The input is written to a memref.alloca buffer, the buffer is collapsed to
// 16 elements, and those are read back as a single scalable
// vector<[16]xi32>, which is then shape-cast to vector<4x[4]xi32>.
//
// NOTE(review): `in_bounds = [true]` on a [16]-element scalable read from a
// 16-element buffer is only accurate when vscale == 1. @main calls
// @setArmVLBits(128) before calling this helper, which presumably pins
// vscale to 1 for i32 lanes - confirm before adding other callers.
23+ func.func private @prepareAccTestData (%in: vector <4 x4 xi32 >) -> vector <4 x[4 ]xi32 > {
24+ %c0 = arith.constant 0 : index
25+ %c0_i32 = arith.constant 0 : i32
26+
// Spill the fixed-size input to memory so it can be reloaded with a
// different (scalable) vector type.
27+ %mem = memref.alloca () : memref <4 x4 xi32 >
28+ vector.transfer_write %in , %mem [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
29+
// View the 4x4 buffer as 16 elements and load them as one scalable vector;
// %c0_i32 is the padding value of the transfer_read.
30+ %flat_mem = memref.collapse_shape %mem [[0 , 1 ]] : memref <4 x4 xi32 > into memref <16 xi32 >
31+ %flat_vec = vector.transfer_read %flat_mem [%c0 ], %c0_i32 {in_bounds = [true ]} : memref <16 xi32 >, vector <[16 ]xi32 >
// Reinterpret the flat scalable vector as 4 rows of [4] i32 elements.
32+ %out = vector.shape_cast %flat_vec : vector <[16 ]xi32 > to vector <4 x[4 ]xi32 >
33+
34+ return %out : vector <4 x[4 ]xi32 >
35+ }
36+
// Round-trips the fixed-size 4x8 i8 LHS operand through a memref.alloca
// buffer and returns it with an unchanged type (vector<4x8xi8>).
//
// NOTE(review): presumably the write/read pair exists so that the operand
// reaching the contraction in @main is a loaded value rather than a foldable
// constant - confirm; SOURCE does not state the reason.
37+ func.func private @prepareLHSTestData (%in: vector <4 x8 xi8 >) -> vector <4 x8 xi8 > {
38+ %c0 = arith.constant 0 : index
39+ %c0_i8 = arith.constant 0 : i8
40+
// Store the input into a buffer of the same shape.
41+ %mem = memref.alloca () : memref <4 x8 xi8 >
42+ vector.transfer_write %in , %mem [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
43+
// Read the whole buffer back; %c0_i8 is the padding value of the
// transfer_read.
44+ %out = vector.transfer_read %mem [%c0 , %c0 ], %c0_i8 : memref <4 x8 xi8 >, vector <4 x8 xi8 >
45+
46+ return %out : vector <4 x8 xi8 >
47+ }
48+
// Converts the fixed-size 4x8 i8 RHS operand into a flat scalable
// vector<[32]xi8>. The caller (@main) shape-casts the result to
// vector<[4]x8xi8> itself.
//
// The input is written to a memref.alloca buffer, the buffer is collapsed to
// 32 elements, and those are read back as one scalable vector.
//
// NOTE(review): `in_bounds = [true]` on a [32]-element scalable read from a
// 32-element buffer is only accurate when vscale == 1. @main calls
// @setArmVLBits(128) before calling this helper, which presumably pins
// that - confirm before adding other callers.
49+ func.func private @prepareRHSTestData (%in: vector <4 x8 xi8 >) -> vector <[32 ]xi8 > {
50+ %c0 = arith.constant 0 : index
51+ %c0_i8 = arith.constant 0 : i8
52+
// Spill the fixed-size input to memory so it can be reloaded with a
// scalable vector type.
53+ %mem = memref.alloca () : memref <4 x8 xi8 >
54+ vector.transfer_write %in , %mem [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
55+
// View the 4x8 buffer as 32 elements and load them as one scalable vector;
// %c0_i8 is the padding value of the transfer_read.
56+ %flat_mem = memref.collapse_shape %mem [[0 , 1 ]] : memref <4 x8 xi8 > into memref <32 xi8 >
57+ %flat_vec = vector.transfer_read %flat_mem [%c0 ], %c0_i8 {in_bounds = [true ]} : memref <32 xi8 >, vector <[32 ]xi8 >
58+
59+ return %flat_vec : vector <[32 ]xi8 >
60+ }
61+
2362func.func @main () {
2463 %c128 = arith.constant 128 : i32
2564 func.call @setArmVLBits (%c128 ) : (i32 ) -> ()
@@ -28,68 +67,32 @@ func.func @main() {
2867 %c0_i32 = arith.constant 0 : i32
2968 %c0_i8 = arith.constant 0 : i8
3069
31- // Accumulator test data
70+ // Accumulator test data
3271 %acc_cst = arith.constant dense <[[-44 , 20 , 44 , -46 ],
3372 [ -8 , 25 , -34 , 26 ],
3473 [-20 , -36 , -3 , 39 ],
3574 [-48 , -31 , -25 , -21 ]]> : vector <4 x4 xi32 >
36- %acc_m = memref.alloca () : memref <4 x4 xi32 >
37- vector.transfer_write %acc_cst , %acc_m [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
38-
39- %acc_m1 = memref.collapse_shape %acc_m [[0 , 1 ]] : memref <4 x4 xi32 > into memref <16 xi32 >
40- %acc_flat = vector.transfer_read %acc_m1 [%c0 ], %c0_i32 {in_bounds = [true ]} : memref <16 xi32 >, vector <[16 ]xi32 >
41- %acc = vector.shape_cast %acc_flat : vector <[16 ]xi32 > to vector <4 x[4 ]xi32 >
42-
43- vector.print str " ACC:\n "
44- %acc0 = vector.extract %acc [0 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
45- %acc1 = vector.extract %acc [1 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
46- %acc2 = vector.extract %acc [2 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
47- %acc3 = vector.extract %acc [3 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
48- vector.print %acc0 : vector <[4 ]xi32 >
49- vector.print %acc1 : vector <[4 ]xi32 >
50- vector.print %acc2 : vector <[4 ]xi32 >
51- vector.print %acc3 : vector <[4 ]xi32 >
75+
76+ %acc = func.call @prepareAccTestData (%acc_cst ) : (vector <4 x4 xi32 >) -> vector <4 x[4 ]xi32 >
5277
5378 // LHS test data
5479 %lhs_cst = arith.constant dense <[[-35 , -27 , -36 , -31 , 23 , -34 , -8 , -33 ],
55- [-20 , 17 , -32 , -47 , 37 , 22 , -7 , -21 ],
56- [ -7 , -35 , 20 , -4 , 39 , 46 , -23 , 40 ],
57- [ 40 , 27 , 37 , 43 , 38 , -6 , 37 , 49 ]]> : vector <4 x8 xi8 >
58-
59- %lhs_m = memref.alloca () : memref <4 x8 xi8 >
60- vector.transfer_write %lhs_cst , %lhs_m [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
61- %lhs = vector.transfer_read %lhs_m [%c0 , %c0 ], %c0_i8 : memref <4 x8 xi8 >, vector <4 x8 xi8 >
62-
63- vector.print str " LHS:\n "
64- %lhs0 = vector.extract %lhs [0 ] : vector <8 xi8 > from vector <4 x8 xi8 >
65- %lhs1 = vector.extract %lhs [1 ] : vector <8 xi8 > from vector <4 x8 xi8 >
66- %lhs2 = vector.extract %lhs [2 ] : vector <8 xi8 > from vector <4 x8 xi8 >
67- %lhs3 = vector.extract %lhs [3 ] : vector <8 xi8 > from vector <4 x8 xi8 >
68- vector.print %lhs0 : vector <8 xi8 >
69- vector.print %lhs1 : vector <8 xi8 >
70- vector.print %lhs2 : vector <8 xi8 >
71- vector.print %lhs3 : vector <8 xi8 >
80+ [-20 , 17 , -32 , -47 , 37 , 22 , -7 , -21 ],
81+ [ -7 , -35 , 20 , -4 , 39 , 46 , -23 , 40 ],
82+ [ 40 , 27 , 37 , 43 , 38 , -6 , 37 , 49 ]]> : vector <4 x8 xi8 >
83+
84+ %lhs = func.call @prepareLHSTestData (%lhs_cst ) : (vector <4 x8 xi8 >) -> vector <4 x8 xi8 >
7285
7386 // RHS test data
7487 %rhs_cst = arith.constant dense <[[-17 , -50 , -1 , 48 , -13 , 22 , 39 , 33 ],
7588 [-35 , -24 , 37 , -32 , 33 , 30 , -11 , -17 ],
7689 [-28 , 31 , 3 , -44 , -15 , -27 , 22 , 35 ],
7790 [-23 , 39 , 48 , 26 , -23 , 32 , -39 , -38 ]]> : vector <4 x8 xi8 >
78-
79- %rhs_m = memref.alloca () : memref <4 x8 xi8 >
80- vector.transfer_write %rhs_cst , %rhs_m [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
81-
82- %rhs_m1 = memref.collapse_shape %rhs_m [[0 , 1 ]] : memref <4 x8 xi8 > into memref <32 xi8 >
83- %rhs_flat = vector.transfer_read %rhs_m1 [%c0 ], %c0_i8 {in_bounds = [true ]} : memref <32 xi8 >, vector <[32 ]xi8 >
84-
85- vector.print str " RHS:\n "
86- %rhs0 = vector.scalable.extract %rhs_flat [0 ] : vector <[16 ]xi8 > from vector <[32 ]xi8 >
87- %rhs1 = vector.scalable.extract %rhs_flat [16 ] : vector <[16 ]xi8 > from vector <[32 ]xi8 >
88- vector.print %rhs0 : vector <[16 ]xi8 >
89- vector.print %rhs1 : vector <[16 ]xi8 >
90-
91+ %rhs_flat = func.call @prepareRHSTestData (%rhs_cst ) : (vector <4 x8 xi8 >) -> vector <[32 ]xi8 >
9192 %rhs = vector.shape_cast %rhs_flat : vector <[32 ]xi8 > to vector <[4 ]x8 xi8 >
9293
94+ // CHECK-IR-COUNT-4: arm_sve.intr.smmla
95+
9396 // Matrix multiplication
9497 %0 = arith.extsi %lhs : vector <4 x8 xi8 > to vector <4 x8 xi32 >
9598 %1 = arith.extsi %rhs : vector <[4 ]x8 xi8 > to vector <[4 ]x8 xi32 >
0 commit comments