Skip to content

Commit 7cdbd8a

Browse files
committed
fixup! fixup! [mlir][vector] Update CombineContractBroadcastMask
Swap masked and scalable tests (thanks Hanhan)
1 parent f3e3aff commit 7cdbd8a

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -212,25 +212,25 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
212212
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
213213
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
214214

215-
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
216-
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
217-
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
215+
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
216+
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
217+
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
218218
// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
219219
// CHECK-SAME: vector.contract
220220
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
221221
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
222-
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
223-
func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
222+
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
223+
func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
224224
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
225-
%1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
225+
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
226226
%result = vector.mask %mask {
227227
vector.contract {
228228
indexing_maps = [#map0, #map1, #map2],
229229
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
230230
kind = #vector.kind<add>
231-
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
232-
} : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
233-
return %result : vector<8x[8]xi32>
231+
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
232+
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
233+
return %result : vector<8x8xi32>
234234
}
235235

236236
// -----
@@ -245,25 +245,25 @@ func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<
245245
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
246246
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
247247

248-
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
249-
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
250-
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
248+
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
249+
// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
250+
// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
251251
// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
252252
// CHECK-SAME: vector.contract
253253
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
254254
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
255-
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
256-
func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
255+
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
256+
func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
257257
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
258-
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
258+
%1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
259259
%result = vector.mask %mask {
260260
vector.contract {
261261
indexing_maps = [#map0, #map1, #map2],
262262
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
263263
kind = #vector.kind<add>
264-
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
265-
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
266-
return %result : vector<8x8xi32>
264+
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
265+
} : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
266+
return %result : vector<8x[8]xi32>
267267
}
268268

269269
// -----

0 commit comments

Comments
 (0)