@@ -212,25 +212,25 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
212212// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
213213// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
214214
215- // CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
216- // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32 >, %[[ARG2:.+]]: vector<8x[8]xi32 >, %[[MASK:.+]]: vector<1x8x[8]x4xi1 >)
217- // CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1 > to vector<8x[8]x4xi1 >
215+ // CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
216+ // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32 >, %[[ARG2:.+]]: vector<8x8xi32 >, %[[MASK:.+]]: vector<1x8x8x4xi1 >)
217+ // CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1 > to vector<8x8x4xi1 >
218218// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
219219// CHECK-SAME: vector.contract
220220// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
221221// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
222- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32 > into vector<8x[8]xi32 >
223- func.func @contract_broadcast_unit_dim_reduction_masked_scalable (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <[ 8 ]x 4 x i32 >, %arg2 : vector <8 x[ 8 ]x i32 >, %mask: vector <1 x 8 x[ 8 ]x 4 x i1 >) -> vector <8 x[ 8 ]x i32 > {
222+ // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32 > into vector<8x8xi32 >
223+ func.func @contract_broadcast_unit_dim_reduction_masked (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <8 x 4 x i32 >, %arg2 : vector <8 x 8 x i32 >, %mask: vector <1 x 8 x 8 x 4 x i1 >) -> vector <8 x 8 x i32 > {
224224 %0 = vector.broadcast %arg0 : vector <8 x4 xi32 > to vector <1 x8 x4 xi32 >
225- %1 = vector.broadcast %arg1 : vector <[ 8 ]x 4 x i32 > to vector <1 x[ 8 ]x 4 x i32 >
225+ %1 = vector.broadcast %arg1 : vector <8 x 4 x i32 > to vector <1 x 8 x 4 x i32 >
226226 %result = vector.mask %mask {
227227 vector.contract {
228228 indexing_maps = [#map0 , #map1 , #map2 ],
229229 iterator_types = [" reduction" , " parallel" , " parallel" , " reduction" ],
230230 kind = #vector.kind <add >
231- } %0 , %1 , %arg2 : vector <1 x8 x4 xi32 >, vector <1 x[ 8 ]x 4 x i32 > into vector <8 x[ 8 ]x i32 >
232- } : vector <1 x 8 x[ 8 ]x 4 x i1 > -> vector <8 x[ 8 ]x i32 >
233- return %result : vector <8 x[ 8 ]x i32 >
231+ } %0 , %1 , %arg2 : vector <1 x8 x4 xi32 >, vector <1 x 8 x 4 x i32 > into vector <8 x 8 x i32 >
232+ } : vector <1 x 8 x 8 x 4 x i1 > -> vector <8 x 8 x i32 >
233+ return %result : vector <8 x 8 x i32 >
234234}
235235
236236// -----
@@ -245,25 +245,25 @@ func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<
245245// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
246246// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
247247
248- // CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
249- // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32 >, %[[ARG2:.+]]: vector<8x8xi32 >, %[[MASK:.+]]: vector<1x8x8x4xi1 >)
250- // CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1 > to vector<8x8x4xi1 >
248+ // CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
249+ // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32 >, %[[ARG2:.+]]: vector<8x[8]xi32 >, %[[MASK:.+]]: vector<1x8x[8]x4xi1 >)
250+ // CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1 > to vector<8x[8]x4xi1 >
251251// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
252252// CHECK-SAME: vector.contract
253253// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
254254// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
255- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32 > into vector<8x8xi32 >
256- func.func @contract_broadcast_unit_dim_reduction_masked (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <8 x 4 x i32 >, %arg2 : vector <8 x 8 x i32 >, %mask: vector <1 x 8 x 8 x 4 x i1 >) -> vector <8 x 8 x i32 > {
255+ // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32 > into vector<8x[8]xi32 >
256+ func.func @contract_broadcast_unit_dim_reduction_masked_scalable (%arg0 : vector <8 x4 xi32 >, %arg1 : vector <[ 8 ]x 4 x i32 >, %arg2 : vector <8 x[ 8 ]x i32 >, %mask: vector <1 x 8 x[ 8 ]x 4 x i1 >) -> vector <8 x[ 8 ]x i32 > {
257257 %0 = vector.broadcast %arg0 : vector <8 x4 xi32 > to vector <1 x8 x4 xi32 >
258- %1 = vector.broadcast %arg1 : vector <8 x 4 x i32 > to vector <1 x 8 x 4 x i32 >
258+ %1 = vector.broadcast %arg1 : vector <[ 8 ]x 4 x i32 > to vector <1 x[ 8 ]x 4 x i32 >
259259 %result = vector.mask %mask {
260260 vector.contract {
261261 indexing_maps = [#map0 , #map1 , #map2 ],
262262 iterator_types = [" reduction" , " parallel" , " parallel" , " reduction" ],
263263 kind = #vector.kind <add >
264- } %0 , %1 , %arg2 : vector <1 x8 x4 xi32 >, vector <1 x 8 x 4 x i32 > into vector <8 x 8 x i32 >
265- } : vector <1 x 8 x 8 x 4 x i1 > -> vector <8 x 8 x i32 >
266- return %result : vector <8 x 8 x i32 >
264+ } %0 , %1 , %arg2 : vector <1 x8 x4 xi32 >, vector <1 x[ 8 ]x 4 x i32 > into vector <8 x[ 8 ]x i32 >
265+ } : vector <1 x 8 x[ 8 ]x 4 x i1 > -> vector <8 x[ 8 ]x i32 >
266+ return %result : vector <8 x[ 8 ]x i32 >
267267}
268268
269269// -----
0 commit comments