4444// Such a `vector.contract` is representative of the code we aim to generate
4545// by vectorisation of `linalg.mmt4d`.
4646//
47- // In this specific test we use M == 4, N == 4, and K == 16.
47+ // In this specific test we use M == 4, N == 4, and K == 8.
4848//
4949
5050// Test the operation where both LHS and RHS are interpreted as signed, hence
@@ -59,43 +59,43 @@ func.func @test_smmla() {
5959 %c0_i8 = arith.constant 0 : i8
6060
6161 // Accumulator test data
62- %acc_cst = arith.constant dense <[[ - 1 , - 9 , - 4 , 0 ],
63- [ 6 , 5 , 7 , 2 ],
64- [ - 8 , - 7 , 9 , - 10 ],
65- [ 9 , 4 , - 4 , 0 ]]> : vector <4 x4 xi32 >
62+ %acc_cst = arith.constant dense <[[- 44 , 20 , 44 , - 46 ],
63+ [ - 8 , 25 , - 34 , 26 ],
64+ [- 20 , - 36 , - 3 , 39 ],
65+ [- 48 , - 31 , - 25 , - 21 ]]> : vector <4 x4 xi32 >
6666
6767 %acc_mem = memref.alloca () : memref <4 x4 xi32 >
6868 vector.transfer_write %acc_cst , %acc_mem [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
6969 %acc = vector.transfer_read %acc_mem [%c0 , %c0 ], %c0_i32 {in_bounds = [true , true ]} : memref <4 x4 xi32 >, vector <4 x4 xi32 >
7070
7171 // LHS test data
72- %lhs_cst = arith.constant dense <[[ - 4 , - 4 , - 4 , - 6 , 0 , 1 , 6 , 2 , - 1 , 4 , 5 , -8 , 9 , 5 , 4 , 9 ],
73- [ - 1 , 6 , 0 , 7 , - 7 , 8 , 5 , 8 , -7 , 6 , - 2 , 1 , 1 , 5 , - 4 , - 4 ],
74- [ 4 , - 10 , 10 , - 3 , 5 , 3 , 2 , 3 , - 7 , 9 , - 9 , -10 , 7 , - 8 , - 5 , - 2 ],
75- [ 9 , 5 , 8 , 9 , 6 , -3 , - 9 , 7 , - 4 , - 7 , - 2 , 7 , - 8 , 2 , 8 , 7 ]]> : vector <4 x 16 x i8 >
72+ %lhs_cst = arith.constant dense <[[- 35 , - 27 , - 36 , - 31 , 23 , - 34 , -8 , - 33 ],
73+ [- 20 , 17 , - 32 , - 47 , 37 , 22 , -7 , - 21 ],
74+ [ - 7 , - 35 , 20 , - 4 , 39 , 46 , -23 , 40 ],
75+ [ 40 , 27 , 37 , 43 , 38 , -6 , 37 , 49 ]]> : vector <4 x 8 x i8 >
7676
77- %lhs_mem = memref.alloca () : memref <4 x 16 x i8 >
78- vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
79- %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
77+ %lhs_mem = memref.alloca () : memref <4 x 8 x i8 >
78+ vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
79+ %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
8080
8181 // RHS test data
82- %rhs_cst = arith.constant dense <[[ 1 , 2 , - 3 , 5 , 10 , 8 , 10 , - 2 , 1 , 10 , - 5 , 2 , 4 , 3 , - 9 , 4 ],
83- [ - 3 , - 3 , - 3 , 4 , 6 , - 1 , 0 , - 5 , 6 , 3 , - 1 , 9 , - 3 , 3 , - 2 , 4 ],
84- [ 1 , 9 , - 1 , 1 , - 5 , 4 , 9 , - 10 , - 1 , - 7 , 10 , - 2 , 0 , - 3 , 4 , 7 ],
85- [ - 4 , - 10 , 8 , - 10 , - 5 , - 8 , - 6 , 7 , 4 , - 2 , 10 , 3 , - 9 , 5 , 2 , - 1 ]]> : vector <4 x 16 x i8 >
82+ %rhs_cst = arith.constant dense <[[- 17 , - 50 , - 1 , 48 , - 13 , 22 , 39 , 33 ],
83+ [- 35 , - 24 , 37 , - 32 , 33 , 30 , - 11 , - 17 ],
84+ [- 28 , 31 , 3 , - 44 , - 15 , - 27 , 22 , 35 ],
85+ [- 23 , 39 , 48 , 26 , - 23 , 32 , - 39 , - 38 ]]> : vector <4 x 8 x i8 >
8686
87- %rhs_mem = memref.alloca () : memref <4 x 16 x i8 >
88- vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
89- %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
87+ %rhs_mem = memref.alloca () : memref <4 x 8 x i8 >
88+ vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
89+ %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
9090
9191
9292 // Matrix multiplication and accumulate with transposed RHS.
93- %0 = arith.extsi %lhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
94- %1 = arith.extsi %rhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
93+ %0 = arith.extsi %lhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
94+ %1 = arith.extsi %rhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
9595 %2 = vector.contract {index ing_maps = #packed_maps ,
9696 iterator_types = [" parallel" , " parallel" , " reduction" ],
9797 kind = #vector.kind <add >} %0 , %1 , %acc
98- : vector <4 x 16 x i32 >, vector <4 x 16 x i32 > into vector <4 x4 xi32 >
98+ : vector <4 x 8 x i32 >, vector <4 x 8 x i32 > into vector <4 x4 xi32 >
9999
100100 // Display the result of the multiplication
101101 vector.print str " Result(SMMLA):\n "
@@ -123,42 +123,42 @@ func.func @test_ummla() {
123123 %c0_i8 = arith.constant 0 : i8
124124
125125 // Accumulator test data
126- %acc_cst = arith.constant dense <[[39 , 39 , 46 , 30 ],
127- [22 , 48 , 61 , 54 ],
128- [41 , 63 , 27 , 10 ],
129- [37 , 30 , 16 , 45 ]]> : vector <4 x4 xi32 >
126+ %acc_cst = arith.constant dense <[[16 , 16 , 48 , 40 ],
127+ [40 , 24 , 35 , 12 ],
128+ [33 , 24 , 29 , 19 ],
129+ [28 , 13 , 33 , 18 ]]> : vector <4 x4 xi32 >
130130
131131 %acc_mem = memref.alloca () : memref <4 x4 xi32 >
132132 vector.transfer_write %acc_cst , %acc_mem [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
133133 %acc = vector.transfer_read %acc_mem [%c0 , %c0 ], %c0_i32 {in_bounds = [true , true ]} : memref <4 x4 xi32 >, vector <4 x4 xi32 >
134134
135135 // LHS test data
136- %lhs_cst = arith.constant dense <[[ 6 , 6 , 38 , 30 , 60 , 4 , 42 , 11 , 16 , 12 , 30 , 41 , 14 , 55 , 47 , 25 ],
137- [ 2 , 19 , 25 , 29 , 15 , 23 , 14 , 19 , 9 , 16 , 42 , 17 , 58 , 62 , 30 , 3 ],
138- [62 , 50 , 47 , 18 , 3 , 48 , 23 , 8 , 43 , 29 , 43 , 15 , 6 , 38 , 46 , 25 ],
139- [32 , 27 , 52 , 39 , 47 , 26 , 26 , 13 , 23 , 29 , 24 , 44 , 23 , 45 , 35 , 51 ]]> : vector <4 x 16 x i8 >
136+ %lhs_cst = arith.constant dense <[[35 , 42 , 37 , 49 , 36 , 36 , 23 , 33 ],
137+ [39 , 34 , 33 , 45 , 43 , 10 , 44 , 47 ],
138+ [18 , 35 , 29 , 25 , 36 , 33 , 28 , 29 ],
139+ [26 , 49 , 43 , 32 , 27 , 16 , 45 , 33 ]]> : vector <4 x 8 x i8 >
140140
141- %lhs_mem = memref.alloca () : memref <4 x 16 x i8 >
142- vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
143- %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
141+ %lhs_mem = memref.alloca () : memref <4 x 8 x i8 >
142+ vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
143+ %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
144144
145145 // RHS test data
146- %rhs_cst = arith.constant dense <[[33 , 0 , 49 , 34 , 37 , 8 , 25 , 19 , 15 , 26 , 23 , 18 , 19 , 16 , 39 , 33 ],
147- [22 , 17 , 53 , 58 , 6 , 35 , 54 , 23 , 8 , 53 , 21 , 27 , 49 , 25 , 34 , 12 ],
148- [27 , 18 , 53 , 53 , 49 , 11 , 12 , 39 , 62 , 47 , 59 , 29 , 20 , 18 , 52 , 25 ],
149- [27 , 40 , 11 , 52 , 37 , 60 , 29 , 44 , 46 , 25 , 13 , 33 , 14 , 53 , 56 , 39 ]]> : vector <4 x 16 x i8 >
146+ %rhs_cst = arith.constant dense <[[18 , 31 , 37 , 35 , 44 , 22 , 37 , 28 ],
147+ [21 , 22 , 49 , 39 , 30 , 28 , 35 , 37 ],
148+ [21 , 47 , 39 , 35 , 23 , 43 , 24 , 49 ],
149+ [49 , 49 , 40 , 32 , 37 , 20 , 47 , 40 ]]> : vector <4 x 8 x i8 >
150150
151- %rhs_mem = memref.alloca () : memref <4 x 16 x i8 >
152- vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
153- %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
151+ %rhs_mem = memref.alloca () : memref <4 x 8 x i8 >
152+ vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
153+ %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
154154
155155 // Matrix multiplication and accumulate with transposed RHS.
156- %0 = arith.extui %lhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
157- %1 = arith.extui %rhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
156+ %0 = arith.extui %lhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
157+ %1 = arith.extui %rhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
158158 %2 = vector.contract {index ing_maps = #packed_maps ,
159159 iterator_types = [" parallel" , " parallel" , " reduction" ],
160160 kind = #vector.kind <add >} %0 , %1 , %acc
161- : vector <4 x 16 x i32 >, vector <4 x 16 x i32 > into vector <4 x4 xi32 >
161+ : vector <4 x 8 x i32 >, vector <4 x 8 x i32 > into vector <4 x4 xi32 >
162162
163163 // Display the result of the multiplication
164164 vector.print str " Result(UMMLA):\n "
@@ -187,42 +187,42 @@ func.func @test_usmmla() {
187187 %c0_i8 = arith.constant 0 : i8
188188
189189 // Accumulator test data
190- %acc_cst = arith.constant dense <[[-50 , 22 , - 15 , 6 ],
191- [ 0 , - 46 , 32 , -59 ],
192- [-62 , -60 , - 38 , 17 ],
193- [-50 , 8 , -12 , 22 ]]> : vector <4 x4 xi32 >
190+ %acc_cst = arith.constant dense <[[-44 , 20 , 44 , - 46 ],
191+ [ - 8 , 25 , -34 , 26 ],
192+ [-20 , -36 , - 3 , 39 ],
193+ [-48 , - 31 , -25 , - 21 ]]> : vector <4 x4 xi32 >
194194
195195 %acc_mem = memref.alloca () : memref <4 x4 xi32 >
196196 vector.transfer_write %acc_cst , %acc_mem [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
197197 %acc = vector.transfer_read %acc_mem [%c0 , %c0 ], %c0_i32 {in_bounds = [true , true ]} : memref <4 x4 xi32 >, vector <4 x4 xi32 >
198198
199199 // LHS test data
200- %lhs_cst = arith.constant dense <[[ 6 , 6 , 38 , 30 , 60 , 4 , 42 , 11 , 16 , 12 , 30 , 41 , 14 , 55 , 47 , 25 ],
201- [ 2 , 19 , 25 , 29 , 15 , 23 , 14 , 19 , 9 , 16 , 42 , 17 , 58 , 62 , 30 , 3 ],
202- [62 , 50 , 47 , 18 , 3 , 48 , 23 , 8 , 43 , 29 , 43 , 15 , 6 , 38 , 46 , 25 ],
203- [32 , 27 , 52 , 39 , 47 , 26 , 26 , 13 , 23 , 29 , 24 , 44 , 23 , 45 , 35 , 51 ]]> : vector <4 x 16 x i8 >
200+ %lhs_cst = arith.constant dense <[[153 , 161 , 24 , 157 , 211 , 154 , 52 , 27 ],
201+ [168 , 77 , 136 , 124 , 249 , 28 , 13 , 122 ],
202+ [ 97 , 82 , 181 , 39 , 53 , 25 , 80 , 240 ],
203+ [184 , 227 , 106 , 165 , 126 , 113 , 121 , 228 ]]> : vector <4 x 8 x i8 >
204204
205- %lhs_mem = memref.alloca () : memref <4 x 16 x i8 >
206- vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
207- %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
205+ %lhs_mem = memref.alloca () : memref <4 x 8 x i8 >
206+ vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
207+ %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
208208
209209 // RHS test data
210- %rhs_cst = arith.constant dense <[[ - 9 , - 10 , 7 , - 8 , - 5 , - 2 , 9 , 5 , 8 , 9 , 6 , - 3 , - 9 , 7 , - 4 , - 7 ],
211- [ - 2 , 7 , - 8 , 2 , 8 , 7 , 1 , 2 , - 3 , 5 , 8 , - 2 , 1 , - 5 , 2 , 4 ],
212- [ 3 , - 9 , 4 , - 3 , - 3 , - 3 , 4 , 6 , - 1 , 0 , - 5 , 6 , 3 , - 1 , 9 , - 3 ],
213- [ 3 , - 2 , 4 , 1 , 9 , - 1 , 1 , - 5 , 4 , 9 , - 10 , - 1 , - 7 , - 2 , 0 , - 3 ]]> : vector <4 x 16 x i8 >
210+ %rhs_cst = arith.constant dense <[[ 40 , 27 , 37 , 43 , 38 , - 6 , 37 , 49 ],
211+ [- 17 , - 50 , - 1 , 48 , - 13 , 22 , 39 , 33 ],
212+ [- 35 , - 24 , 37 , - 32 , 33 , 30 , - 11 , - 17 ],
213+ [- 28 , 31 , 3 , - 44 , - 15 , - 27 , 22 , 35 ]]> : vector <4 x 8 x i8 >
214214
215- %rhs_mem = memref.alloca () : memref <4 x 16 x i8 >
216- vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
217- %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
215+ %rhs_mem = memref.alloca () : memref <4 x 8 x i8 >
216+ vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
217+ %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
218218
219219 // Matrix multiplication and accumulate with transposed RHS.
220- %0 = arith.extui %lhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
221- %1 = arith.extsi %rhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
220+ %0 = arith.extui %lhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
221+ %1 = arith.extsi %rhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
222222 %2 = vector.contract {index ing_maps = #packed_maps ,
223223 iterator_types = [" parallel" , " parallel" , " reduction" ],
224224 kind = #vector.kind <add >} %0 , %1 , %acc
225- : vector <4 x 16 x i32 >, vector <4 x 16 x i32 > into vector <4 x4 xi32 >
225+ : vector <4 x 8 x i32 >, vector <4 x 8 x i32 > into vector <4 x4 xi32 >
226226
227227 // Display the result of the multiplication
228228 vector.print str " Result(USMMLA):\n "
@@ -252,42 +252,42 @@ func.func @test_summla() {
252252 %c0_i8 = arith.constant 0 : i8
253253
254254 // Accumulator test data
255- %acc_cst = arith.constant dense <[[-61 , 52 , 8 , -54 ],
256- [- 25 , - 50 , 22 , -15 ],
257- [ 6 , 0 , - 46 , 32 ],
258- [-59 , -62 , -60 , -38 ]]> : vector <4 x4 xi32 >
255+ %acc_cst = arith.constant dense <[[-44 , 20 , 44 , -46 ],
256+ [ - 8 , 25 , -34 , 26 ],
257+ [- 20 , - 36 , - 3 , 39 ],
258+ [-48 , -31 , -25 , -21 ]]> : vector <4 x4 xi32 >
259259
260260 %acc_mem = memref.alloca () : memref <4 x4 xi32 >
261261 vector.transfer_write %acc_cst , %acc_mem [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
262262 %acc = vector.transfer_read %acc_mem [%c0 , %c0 ], %c0_i32 {in_bounds = [true , true ]} : memref <4 x4 xi32 >, vector <4 x4 xi32 >
263263
264264 // LHS test data
265- %lhs_cst = arith.constant dense <[[ - 4 , - 4 , - 4 , - 6 , 0 , 1 , 6 , 2 , - 1 , 4 , 5 , -8 , 9 , 5 , 4 , 9 ],
266- [ - 1 , 6 , 0 , 7 , - 7 , 8 , 5 , 8 , -7 , 6 , - 2 , 1 , 1 , 5 , - 4 , - 4 ],
267- [ 4 , - 10 , - 3 , 5 , 3 , 2 , 3 , - 7 , 9 , - 9 , -10 , 7 , - 8 , - 5 , - 2 , 9 ],
268- [ 5 , 8 , 9 , 6 , - 3 , -9 , 7 , - 4 , - 7 , - 2 , 7 , - 8 , 2 , 8 , 7 , 1 ]]> : vector <4 x 16 x i8 >
265+ %lhs_cst = arith.constant dense <[[- 35 , - 27 , - 36 , - 31 , 23 , - 34 , -8 , - 33 ],
266+ [- 20 , 17 , - 32 , - 47 , 37 , 22 , -7 , - 21 ],
267+ [ - 7 , - 35 , 20 , - 4 , 39 , 46 , -23 , 40 ],
268+ [ 40 , 27 , 37 , 43 , 38 , -6 , 37 , 49 ]]> : vector <4 x 8 x i8 >
269269
270- %lhs_mem = memref.alloca () : memref <4 x 16 x i8 >
271- vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
272- %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
270+ %lhs_mem = memref.alloca () : memref <4 x 8 x i8 >
271+ vector.transfer_write %lhs_cst , %lhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
272+ %lhs = vector.transfer_read %lhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
273273
274274 // RHS test data
275- %rhs_cst = arith.constant dense <[[12 , 39 , 62 , 47 , 59 , 29 , 20 , 18 , 52 , 25 , 27 , 40 , 11 , 52 , 37 , 60 ],
276- [29 , 44 , 46 , 25 , 13 , 33 , 14 , 53 , 56 , 39 , 39 , 39 , 46 , 30 , 22 , 48 ],
277- [61 , 54 , 41 , 63 , 27 , 10 , 37 , 30 , 16 , 45 , 41 , 51 , 39 , 28 , 13 , 28 ],
278- [21 , 28 , 24 , 40 , 46 , 30 , 11 , 19 , 9 , 11 , 5 , 46 , 19 , 26 , 0 , 9 ]]> : vector <4 x 16 x i8 >
275+ %rhs_cst = arith.constant dense <[[125 , 171 , 138 , 187 , 108 , 175 , 82 , 99 ],
276+ [221 , 25 , 164 , 97 , 156 , 221 , 218 , 177 ],
277+ [171 , 160 , 219 , 191 , 144 , 45 , 161 , 210 ],
278+ [223 , 165 , 123 , 99 , 108 , 86 , 37 , 92 ]]> : vector <4 x 8 x i8 >
279279
280- %rhs_mem = memref.alloca () : memref <4 x 16 x i8 >
281- vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 16 x i8 >, memref <4 x 16 x i8 >
282- %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 16 x i8 >, vector <4 x 16 x i8 >
280+ %rhs_mem = memref.alloca () : memref <4 x 8 x i8 >
281+ vector.transfer_write %rhs_cst , %rhs_mem [%c0 , %c0 ] : vector <4 x 8 x i8 >, memref <4 x 8 x i8 >
282+ %rhs = vector.transfer_read %rhs_mem [%c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <4 x 8 x i8 >, vector <4 x 8 x i8 >
283283
284284 // Matrix multiplication and accumulate with transposed RHS.
285- %0 = arith.extsi %lhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
286- %1 = arith.extui %rhs : vector <4 x 16 x i8 > to vector <4 x 16 x i32 >
285+ %0 = arith.extsi %lhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
286+ %1 = arith.extui %rhs : vector <4 x 8 x i8 > to vector <4 x 8 x i32 >
287287 %2 = vector.contract {index ing_maps = #packed_maps ,
288288 iterator_types = [" parallel" , " parallel" , " reduction" ],
289289 kind = #vector.kind <add >} %0 , %1 , %acc
290- : vector <4 x 16 x i32 >, vector <4 x 16 x i32 > into vector <4 x4 xi32 >
290+ : vector <4 x 8 x i32 >, vector <4 x 8 x i32 > into vector <4 x4 xi32 >
291291
292292 // Display the result of the multiplication
293293 vector.print str " Result(SUMMLA (i.e. USMMLA transposed)):\n "
@@ -305,31 +305,31 @@ func.func @test_summla() {
305305
306306func.func @main () {
307307// CHECK-LABEL: Result(SMMLA):
308- // CHECK: ( 82 , -63 , 95, 11 )
309- // CHECK: ( 184 , -81 , -17, -172 )
310- // CHECK: ( 168, -158 , -251, -133 )
311- // CHECK: ( -139, 40, -48 , 75 )
308+ // CHECK: ( -1999 , 1941 , 685, -2879 )
309+ // CHECK: ( -3705 , 2952 , 987, -685 )
310+ // CHECK: ( 2565, 4157 , -1589, -357 )
311+ // CHECK: ( 2383, -2252 , 32, -1365 )
312312 func.call @test_smmla () : () -> ()
313313
314314// CHECK-LABEL: Result(UMMLA):
315- // CHECK: ( 12414, 13508, 16691, 16069 )
316- // CHECK: ( 8935, 13219, 13408, 13644 )
317- // CHECK: ( 12223, 15233, 18131, 18553 )
318- // CHECK: ( 14459, 16573, 19443, 19417 )
315+ // CHECK: ( 9183, 9513, 10460, 11314 )
316+ // CHECK: ( 9648, 9812, 10092, 12088 )
317+ // CHECK: ( 7548, 7625, 8398, 9044 )
318+ // CHECK: ( 8855, 9046, 9685, 11191 )
319319 func.call @test_ummla () : () -> ()
320320
321321// CHECK-LABEL: Result(USMMLA):
322- // CHECK: ( 176 , 483, 468 , 265 )
323- // CHECK: ( 23 , 449, 192, -727 )
324- // CHECK: ( -128 , 563 , -30 , 66 )
325- // CHECK: ( -476 , 657, 202 , 334 )
322+ // CHECK: ( 28403 , 445 , -2759, -11409 )
323+ // CHECK: ( 34908, 1047 , 142, -7274 )
324+ // CHECK: ( 31032 , 6807 , -2378 , 7382 )
325+ // CHECK: ( 44217 , 6396, -10930 , 623 )
326326 func.call @test_usmmla () : () -> ()
327327
328328// CHECK-LABEL: Result(SUMMLA (i.e. USMMLA transposed)):
329- // CHECK: ( 300, 716, 54 , -378 )
330- // CHECK: ( 244 , 746, 1184 , 689 )
331- // CHECK: ( 253, -655, -688 , 115 )
332- // CHECK: ( 995 , 574, 1490 , 177 )
329+ // CHECK: ( -27190, -28812, -30502 , -23575 )
330+ // CHECK: ( -7613 , -8386, -15938 , -6521 )
331+ // CHECK: ( 9468, 18750, 9199 , 5764 )
332+ // CHECK: ( 33655 , 41064, 48900 , 31627 )
333333 func.call @test_summla () : () -> ()
334334
335335 return
0 commit comments