Skip to content

Commit 9987bcc

Browse files
committed
[mlir][linalg] set inbounds to all-true for xfer reads/writes for assumeDynamicDimsMatchVecSizes
1 parent 010f96a commit 9987bcc

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,23 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
524524

525525
if (!mask) {
526526
LDBG() << "No mask required";
527+
if (assumeDynamicDimsMatchVecSizes) {
528+
LDBG() << "Assuming dynamic dimensions match vector sizes!";
529+
// Set inbounds to all-true.
530+
llvm::TypeSwitch<Operation *>(opToMask)
531+
.Case<vector::TransferReadOp, vector::TransferWriteOp>(
532+
[&](auto xferOp) {
533+
SmallVector<bool> inBoundsMap(xferOp.getInBounds().size(),
534+
true);
535+
rewriter.modifyOpInPlace(xferOp, [&]() {
536+
xferOp.setInBoundsAttr(
537+
rewriter.getBoolArrayAttr(inBoundsMap));
538+
});
539+
})
540+
.Default([](Operation *op) {
541+
// No-op if the operation is not an xfer read or write.
542+
});
543+
}
527544
return opToMask;
528545
}
529546

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -918,12 +918,16 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
918918
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
919919
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
920920
// CHECK-NOT: mask
921-
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
922-
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
923-
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
921+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
922+
// CHECK-SAME: in_bounds = [true, true, true, true, true, true]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
923+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
924+
// CHECK-SAME: in_bounds = [true, true, true, true, true, true]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
925+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
926+
// CHECK-SAME: in_bounds = [true, true, true, true]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924927
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925928
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926-
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
929+
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
930+
// CHECK-SAME: in_bounds = [true, true, true, true]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
927931

928932
module attributes {transform.with_named_sequence} {
929933
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -1011,12 +1015,16 @@ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: mem
10111015
// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
10121016
// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
10131017
// CHECK-NOT: mask
1014-
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1015-
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1016-
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
1018+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
1019+
// CHECK-SAME: in_bounds = [true, true, true, true, true, true, true]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1020+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
1021+
// CHECK-SAME: in_bounds = [true, true, true, true, true, true, true]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1022+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
1023+
// CHECK-SAME: in_bounds = [true, true, true, true, true]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
10171024
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
10181025
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
1019-
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
1026+
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
1027+
// CHECK-SAME: in_bounds = [true, true, true, true, true]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
10201028

10211029
module attributes {transform.with_named_sequence} {
10221030
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {

0 commit comments

Comments
 (0)