Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,

if (!mask) {
LDBG() << "No mask required";
if (assumeDynamicDimsMatchVecSizes) {
llvm::TypeSwitch<Operation *>(opToMask)
.Case<vector::TransferReadOp, vector::TransferWriteOp>(
[&](auto xferOp) {
// For vector.transfer_read and vector.transfer_write, there is
// also the `in-bounds` attribute that has to be set explicitly
// to true. Otherwise, "out-of-bounds" access will be assumed
// and masks will be generated while lowering these.
LDBG() << "Assuming dynamic dimensions match vector sizes and "
"setting their in-bounds to true!";
SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
ShapedType xferType = xferOp.getShapedType();
AffineMap permMap = xferOp.getPermutationMap();
// Only set the in-bounds values to true for dynamic dims.
// Different mechanisms will set these accordingly for the
// static dims.
for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
// Skip broadcast dimensions.
if (!dimExpr)
continue;
unsigned pos = dimExpr.getPosition();
if (xferType.isDynamicDim(pos))
inBoundsMap[i] = true;
}
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.setInBoundsAttr(
rewriter.getBoolArrayAttr(inBoundsMap));
});
})
.Default([](Operation *op) {
// No-op if the operation is not an xfer read or write.
});
}
return opToMask;
}

Expand Down
26 changes: 18 additions & 8 deletions mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -918,12 +918,17 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
// CHECK-SAME: memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
// `in-bounds` are set to true for dynamic dims with assume, static sizes will be inferred elsewhere.
// CHECK-SAME: in_bounds = [false, false, false, false, true, false]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
// CHECK-SAME: in_bounds = [false, false, false, true]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
// CHECK-SAME: in_bounds = [false, false, false, true]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
Expand Down Expand Up @@ -1011,12 +1016,17 @@ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: mem
// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
// CHECK-SAME: memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
// `in-bounds` are set to true for dynamic dims with assume, static sizes will be inferred elsewhere.
// CHECK-SAME: in_bounds = [false, false, false, false, false, true, false]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
// CHECK-SAME: in_bounds = [false, false, false, false, true]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
// CHECK-SAME: in_bounds = [false, false, false, false, true]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
Expand Down