Skip to content

Commit 581dfa8

Browse files
egebeysel authored and svkeerthy committed
[mlir][linalg] set inbounds on xfer_read/writes for assumeDynamicDimsMatchVecSizes (#160839)
The idea from #146531 was to introduce the flag `assumeDynamicDimsMatchVecSizes` to signal to the vectorizer that the access should not be masked and is in-bounds. Although the masking part is handled, `xfer_read/write` ops are created without explicitly setting the `in_bounds` attribute, which defaults to all-false. In the presence of scalable tile sizes, subsequent patterns tend to overwrite the `in_bounds` attribute and introduce masks further down when lowered to loads and stores. This PR explicitly sets the `in_bounds` attribute to true for the dynamic dimensions of `xfer_read/write` ops if the `assumeDynamicDimsMatchVecSizes` flag is set; in-bounds values for static dimensions are left to other mechanisms. --------- Signed-off-by: Ege Beysel <[email protected]>
1 parent 4f3eb0e commit 581dfa8

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
524524

525525
if (!mask) {
526526
LDBG() << "No mask required";
527+
if (assumeDynamicDimsMatchVecSizes) {
528+
llvm::TypeSwitch<Operation *>(opToMask)
529+
.Case<vector::TransferReadOp, vector::TransferWriteOp>(
530+
[&](auto xferOp) {
531+
// For vector.transfer_read and vector.transfer_write, there is
532+
// also the `in-bounds` attribute that has to be set explicitly
533+
// to true. Otherwise, "out-of-bounds" access will be assumed
534+
// and masks will be generated while lowering these.
535+
LDBG() << "Assuming dynamic dimensions match vector sizes and "
536+
"setting their in-bounds to true!";
537+
SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
538+
ShapedType xferType = xferOp.getShapedType();
539+
AffineMap permMap = xferOp.getPermutationMap();
540+
// Only set the in-bounds values to true for dynamic dims.
541+
// Different mechanisms will set these accordingly for the
542+
// static dims.
543+
for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
544+
auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
545+
// Skip broadcast dimensions.
546+
if (!dimExpr)
547+
continue;
548+
unsigned pos = dimExpr.getPosition();
549+
if (xferType.isDynamicDim(pos))
550+
inBoundsMap[i] = true;
551+
}
552+
rewriter.modifyOpInPlace(xferOp, [&]() {
553+
xferOp.setInBoundsAttr(
554+
rewriter.getBoolArrayAttr(inBoundsMap));
555+
});
556+
})
557+
.Default([](Operation *op) {
558+
// No-op if the operation is not an xfer read or write.
559+
});
560+
}
527561
return opToMask;
528562
}
529563

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -918,12 +918,17 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
918918
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
919919
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
920920
// CHECK-NOT: mask
921-
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
922-
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
923-
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
921+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
922+
// CHECK-SAME: memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
923+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
924+
// `in-bounds` are set to true for dynamic dims with assume, static sizes will be inferred elsewhere.
925+
// CHECK-SAME: in_bounds = [false, false, false, false, true, false]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
926+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
927+
// CHECK-SAME: in_bounds = [false, false, false, true]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924928
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925929
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926-
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
930+
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
931+
// CHECK-SAME: in_bounds = [false, false, false, true]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
927932

928933
module attributes {transform.with_named_sequence} {
929934
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -1011,12 +1016,17 @@ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: mem
10111016
// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
10121017
// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
10131018
// CHECK-NOT: mask
1014-
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1015-
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1016-
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
1019+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
1020+
// CHECK-SAME: memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1021+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
1022+
// `in-bounds` are set to true for dynamic dims with assume, static sizes will be inferred elsewhere.
1023+
// CHECK-SAME: in_bounds = [false, false, false, false, false, true, false]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
1024+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
1025+
// CHECK-SAME: in_bounds = [false, false, false, false, true]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
10171026
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
10181027
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
1019-
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
1028+
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
1029+
// CHECK-SAME: in_bounds = [false, false, false, false, true]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
10201030

10211031
module attributes {transform.with_named_sequence} {
10221032
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {

0 commit comments

Comments
 (0)