@@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
169169 return %0 : vector <2 x3 xf32 >
170170}
171171
172+ // CHECK-LABEL: func @matmul_scalable
173+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
174+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
175+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
176+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
177+ // CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
178+ //
179+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
180+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
181+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
182+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
183+ //
184+ // CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
185+ // CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
186+ // CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
187+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
188+ //
189+ // CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
190+ // CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
191+ // CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
192+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
193+ //
194+ // CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
195+ // CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
196+ // CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
197+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
198+ //
199+ // CHECK: return %[[c3]] : vector<2x[3]xf32>
200+ func.func @matmul_scalable (%arg0: vector <2 x4 xf32 >,
201+ %arg1: vector <4 x[3 ]xf32 >,
202+ %arg2: vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 > {
203+ %0 = vector.contract #matmat_trait %arg0 , %arg1 , %arg2
204+ : vector <2 x4 xf32 >, vector <4 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
205+ return %0 : vector <2 x[3 ]xf32 >
206+ }
207+
172208// CHECK-LABEL: func @matmul_0
173209// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
174210// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
186222 return %0 : vector <2 x3 xf32 >
187223}
188224
225+ // CHECK-LABEL: func @matmul_0_scalable
226+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
227+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
228+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
229+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
230+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
231+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
232+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
233+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
234+ func.func @matmul_0_scalable (%arg0: vector <2 x1 xf32 >, %arg1: vector <1 x[3 ]xf32 >, %arg2: vector <2 x[3 ]xf32 >)
235+ -> vector <2 x[3 ]xf32 >
236+ {
237+ %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
238+ : vector <2 x1 xf32 >, vector <1 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
239+ return %0 : vector <2 x[3 ]xf32 >
240+ }
241+
189242// CHECK-LABEL: func @matmul_0_mixed
190243// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
191244// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
205258 return %0 : vector <2 x3 xf32 >
206259}
207260
261+ // CHECK-LABEL: func @matmul_0_mixed_scalable
262+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
263+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
264+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
265+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
266+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
267+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
268+ // CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
269+ // CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
270+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
271+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
272+ func.func @matmul_0_mixed_scalable (%arg0: vector <2 x1 xf16 >, %arg1: vector <1 x[3 ]xf16 >, %arg2: vector <2 x[3 ]xf32 >)
273+ -> vector <2 x[3 ]xf32 >
274+ {
275+ %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
276+ : vector <2 x1 xf16 >, vector <1 x[3 ]xf16 > into vector <2 x[3 ]xf32 >
277+ return %0 : vector <2 x[3 ]xf32 >
278+ }
279+
208280#matmat_accesses_1 = [
209281 affine_map <(m , n , k ) -> (m , k )>,
210282 affine_map <(m , n , k ) -> (n , k )>,
@@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
233305 return %0 : vector <2 x3 xf32 >
234306}
235307
308+ // CHECK-LABEL: func @matmul_1_scalable
309+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
310+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
311+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
312+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
313+ // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
314+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
315+ // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
316+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
317+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
318+ func.func @matmul_1_scalable (%arg0: vector <2 x1 xf32 >, %arg1: vector <[3 ]x1 xf32 >, %arg2: vector <2 x[3 ]xf32 >)
319+ -> vector <2 x[3 ]xf32 >
320+ {
321+ %0 = vector.contract #matmat_trait_1 %arg0 , %arg1 , %arg2
322+ : vector <2 x1 xf32 >, vector <[3 ]x1 xf32 > into vector <2 x[3 ]xf32 >
323+ return %0 : vector <2 x[3 ]xf32 >
324+ }
325+
236326#matmat_accesses_2 = [
237327 affine_map <(m , n , k ) -> (k , m )>,
238328 affine_map <(m , n , k ) -> (k , n )>,
@@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
259349 return %0 : vector <2 x3 xf32 >
260350}
261351
352+ // CHECK-LABEL: func @matmul_2_scalable
353+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
354+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
355+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
356+ // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
357+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
358+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
359+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
360+ func.func @matmul_2_scalable (%arg0: vector <1 x2 xf32 >, %arg1: vector <1 x[3 ]xf32 >, %arg2: vector <2 x[3 ]xf32 >)
361+ -> vector <2 x[3 ]xf32 >
362+ {
363+ %0 = vector.contract #matmat_trait_2 %arg0 , %arg1 , %arg2
364+ : vector <1 x2 xf32 >, vector <1 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
365+ return %0 : vector <2 x[3 ]xf32 >
366+ }
367+
262368#matmat_accesses_3 = [
263369 affine_map <(m , n , k ) -> (k , m )>,
264370 affine_map <(m , n , k ) -> (n , k )>,
@@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
286392 return %0 : vector <2 x3 xf32 >
287393}
288394
395+ // CHECK-LABEL: func @matmul_3_scalable
396+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
397+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
398+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
399+ // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
400+ // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
401+ // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
402+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
403+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
404+ func.func @matmul_3_scalable (%arg0: vector <1 x2 xf32 >, %arg1: vector <[3 ]x1 xf32 >, %arg2: vector <2 x[3 ]xf32 >)
405+ -> vector <2 x[3 ]xf32 >
406+ {
407+ %0 = vector.contract #matmat_trait_3 %arg0 , %arg1 , %arg2
408+ : vector <1 x2 xf32 >, vector <[3 ]x1 xf32 > into vector <2 x[3 ]xf32 >
409+ return %0 : vector <2 x[3 ]xf32 >
410+ }
411+
289412#matmat_accesses_4 = [
290413 affine_map <(m , n , k ) -> (m , k )>,
291414 affine_map <(m , n , k ) -> (k , n )>,
@@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
313436 return %0 : vector <3 x2 xf32 >
314437}
315438
439+ // CHECK-LABEL: func @matmul_4_scalable
440+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
441+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
442+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
443+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
444+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
445+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
446+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
447+ // CHECK: return %[[c0]] : vector<3x[2]xf32>
448+ func.func @matmul_4_scalable (%arg0: vector <[2 ]x1 xf32 >, %arg1: vector <1 x3 xf32 >, %arg2: vector <3 x[2 ]xf32 >)
449+ -> vector <3 x[2 ]xf32 >
450+ {
451+ %0 = vector.contract #matmat_trait_4 %arg0 , %arg1 , %arg2
452+ : vector <[2 ]x1 xf32 >, vector <1 x3 xf32 > into vector <3 x[2 ]xf32 >
453+ return %0 : vector <3 x[2 ]xf32 >
454+ }
455+
456+ #matmat_accesses_5 = [
457+ affine_map <(m , n , k ) -> (m , k )>,
458+ affine_map <(m , n , k ) -> (k , n )>,
459+ affine_map <(m , n , k ) -> (n , m )>
460+ ]
461+ #matmat_trait_5 = {
462+ indexing_maps = #matmat_accesses_5 ,
463+ iterator_types = [" parallel" , " parallel" , " reduction" ]
464+ }
465+
316466// CHECK-LABEL: @masked_matvec_mk_k_m
317467// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
318468// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
0 commit comments