Closed
27 changes: 18 additions & 9 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
/// is the case if the all offsets are zero, all strides are 1, and the source
/// shape is same as the size of the subview. In such cases, the subview can
/// be folded into its source.
Member: Can you also update the comment?
-static bool isTrivialSubViewOp(SubViewOp subViewOp) {
+static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) {
if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
return false;

@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
       }))
     return false;
 
-  // Check all size values are static and matches the (static) source shape.
+  // Check all size values match the source shape.
   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
-  for (const auto &size : llvm::enumerate(mixedSizes)) {
-    std::optional<int64_t> intValue = getConstantIntValue(size.value());
-    if (!intValue || *intValue != sourceShape[size.index()])
-      return false;
-  }
-  // All conditions met. The `SubViewOp` is foldable as a no-op.
-  return true;
+  if (llvm::all_of_zip(mixedSizes, sourceShape,
+                       [](OpFoldResult mixedSize, int64_t staticSize) {
+                         std::optional<int64_t> constSize =
+                             getConstantIntValue(mixedSize);
+                         return constSize.has_value() &&
+                                *constSize == staticSize;
+                       })) {
+    return true;
+  }
+  auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource());
+  if (!sourceOpResult)
+    return false;
+  ReifiedRankedShapedTypeDims resultDims;
+  if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims)))
+    return false;
+  return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]);
 }

Contributor (on lines +3134 to +3137): I think you can use isConstantIntValue.

    /// Return true if `ofr` is constant integer equal to `value`.
    bool isConstantIntValue(OpFoldResult ofr, int64_t value);
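As a standalone illustration of the suggested simplification, a minimal sketch that collapses the lambda into one `isConstantIntValue` call per element. The types here are hypothetical stand-ins: `OpFoldResult` is modeled as a `std::variant`, not the real MLIR class.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Stand-in for mlir::OpFoldResult (an assumption of this sketch, not
// the real API): either a folded constant or a handle naming a dynamic
// SSA value.
using OpFoldResult = std::variant<int64_t, std::string>;

std::optional<int64_t> getConstantIntValue(const OpFoldResult &ofr) {
  if (const int64_t *c = std::get_if<int64_t>(&ofr))
    return *c;
  return std::nullopt;
}

// Mirrors the helper the reviewer points at: true iff `ofr` is a
// constant integer equal to `value`.
bool isConstantIntValue(const OpFoldResult &ofr, int64_t value) {
  std::optional<int64_t> c = getConstantIntValue(ofr);
  return c.has_value() && *c == value;
}

// The patch's pairwise size check, with the lambda body replaced by a
// single isConstantIntValue call per element.
bool sizesMatchSourceShape(const std::vector<OpFoldResult> &mixedSizes,
                           const std::vector<int64_t> &sourceShape) {
  if (mixedSizes.size() != sourceShape.size())
    return false;
  for (size_t i = 0; i < mixedSizes.size(); ++i)
    if (!isConstantIntValue(mixedSizes[i], sourceShape[i]))
      return false;
  return true;
}
```

A dynamic size never compares equal to a static extent here, which is exactly why the patch falls through to the reified-shape comparison.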

namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
 
   LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                 PatternRewriter &rewriter) const override {
-    if (!isTrivialSubViewOp(subViewOp))
+    if (!isTrivialSubViewOp(rewriter, subViewOp))
       return failure();
Member: This is problematic: isTrivialSubViewOp creates new IR, but then you return "failure". This is not allowed in a rewrite pattern. You have to make sure that there is no change to the input IR when you return "failure".
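The contract being invoked here can be illustrated with a toy, self-contained model — plain C++ with a list of op names standing in for the IR, not MLIR's real API:

```cpp
#include <string>
#include <vector>

// Toy "IR": just a list of op names (an assumption of this sketch).
using IR = std::vector<std::string>;

// Wrong shape: mutate first, then discover the match fails. The caller
// sees failure, but the input IR has already changed.
bool badPattern(IR &ir) {
  ir.push_back("reified.dim"); // speculative new IR
  bool matched = false;        // match check happens too late
  return matched;              // contract violated: IR changed on failure
}

// Correct shape: all match checks are read-only; mutate only once the
// match is known to succeed.
bool goodPattern(IR &ir) {
  bool matched = !ir.empty() && ir.front() == "memref.subview";
  if (!matched)
    return false;          // no change to the input IR
  ir.front() = "folded";   // rewrite only after matching
  return true;
}
```

The patch as written has the `badPattern` shape: `reifyResultShapes` inside `isTrivialSubViewOp` may create ops even when the pattern then returns failure.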

Contributor Author: Yeah, it is difficult to compare the dynamic sizes without doing this. The other option I considered was using the ValueBoundsOpInterface, but that seems like an expensive check to run in a canonicalization. It would be nice if there were a version of reifyResultShapes that did not create IR, and returned failure if creating new IR was necessary, but that seems like a lot of work to plumb through.

These are the ways I know of to compare the result sizes of a tensor, but if you have any better suggestions, then that would be very helpful!

Contributor Author: After looking into it further, I do see some instances of ValueBoundsOpInterface being used in canonicalization patterns, so I suppose it is the lesser of the two evils. I will update the PR implementation to use it instead.

Contributor Author: Actually, this won't work because I don't think it can capture equality of fully dynamic sizes that come from the same SSA value. It is more about computing bounds. So I am back to looking for something like reifyResultShapes, but without creating extra IR.

Member: I wouldn't use ValueBoundsOpInterface during canonicalization. It is quite expensive and should probably be removed from the canonicalization patterns that use it today.

I don't have a good solution. Maybe we can build a new pass that folds various index computations and view-like ops based on a single ValueBoundsOpInterface analysis. Maybe that pass could then also replace the "reify ranked shaped ..." pass. Not sure if it's a good idea, just trying to think of a solution...

Contributor Author: We have an interface that does what I need in downstream IREE, actually. It is like reifyResultShapes, but "read only", as in, it doesn't create new IR. I am going to close this PR and handle this downstream for the time being, since I don't see a great way to do this here in MLIR.
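The downstream "read only" interface isn't shown here, so the following is only a hedged sketch of the idea: two mixed sizes are provably equal when they are equal constants or the same dynamic SSA value, and the check creates no IR, so a failed match leaves everything untouched. `OpFoldResult` is again modeled as a `std::variant`, an assumption of this sketch.

```cpp
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Stand-in for mlir::OpFoldResult (an assumption of this sketch):
// either a static constant or a handle naming a dynamic SSA value.
using OpFoldResult = std::variant<int64_t, std::string>;

// Read-only equality: equal constants, or the same dynamic SSA value.
// std::variant's operator== compares the active alternative, then the
// held value, so no IR needs to be materialized to compare.
bool provablyEqual(const OpFoldResult &a, const OpFoldResult &b) {
  return a == b;
}

bool shapesProvablyEqual(const std::vector<OpFoldResult> &a,
                         const std::vector<OpFoldResult> &b) {
  if (a.size() != b.size())
    return false;
  for (size_t i = 0; i < a.size(); ++i)
    if (!provablyEqual(a[i], b[i]))
      return false;
  return true;
}
```

Note the deliberate conservatism: two distinct dynamic values that happen to be equal at runtime are not provably equal, so the fold simply doesn't fire — which is the safe answer for a canonicalization.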

if (subViewOp.getSourceType() == subViewOp.getType()) {
rewriter.replaceOp(subViewOp, subViewOp.getSource());
14 changes: 14 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,20 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4

// -----

+// CHECK-LABEL: func @subview_of_dynamic_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<?xi8>
+// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape
+// CHECK-NOT: memref.subview
+// CHECK: return %[[EXPAND_SHAPE]] : memref<?x?xi8>
+func.func @subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index) -> memref<?x?xi8> {
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
+  %1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref<?x?xi8> to memref<?x?xi8>
+  return %1 : memref<?x?xi8>
+}
+
+// -----

// CHECK-LABEL: func @negative_subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK-SAME: %[[IDX:.+]]: index