Skip to content

Commit dc0b38f

Browse files
committed
[mlir] Allow folding dynamic full size subviews
Signed-off-by: Max Dawkins <[email protected]>
1 parent 10d198b commit dc0b38f

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
31053105
/// is the case if the all offsets are zero, all strides are 1, and the source
31063106
/// shape is same as the size of the subview. In such cases, the subview can
31073107
/// be folded into its source.
3108-
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3108+
static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) {
31093109
if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
31103110
return false;
31113111

@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
31273127
}))
31283128
return false;
31293129

3130-
// Check all size values are static and matches the (static) source shape.
3130+
// Check all size values match the source shape.
31313131
ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3132-
for (const auto &size : llvm::enumerate(mixedSizes)) {
3133-
std::optional<int64_t> intValue = getConstantIntValue(size.value());
3134-
if (!intValue || *intValue != sourceShape[size.index()])
3135-
return false;
3132+
if (llvm::all_of_zip(mixedSizes, sourceShape,
3133+
[](OpFoldResult mixedSize, int64_t staticSize) {
3134+
std::optional<int64_t> constSize =
3135+
getConstantIntValue(mixedSize);
3136+
return constSize.has_value() &&
3137+
*constSize == staticSize;
3138+
})) {
3139+
return true;
31363140
}
3137-
// All conditions met. The `SubViewOp` is foldable as a no-op.
3138-
return true;
3141+
auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource());
3142+
if (!sourceOpResult)
3143+
return false;
3144+
ReifiedRankedShapedTypeDims resultDims;
3145+
if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims)))
3146+
return false;
3147+
return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]);
31393148
}
31403149

31413150
namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
32063215

32073216
LogicalResult matchAndRewrite(SubViewOp subViewOp,
32083217
PatternRewriter &rewriter) const override {
3209-
if (!isTrivialSubViewOp(subViewOp))
3218+
if (!isTrivialSubViewOp(rewriter, subViewOp))
32103219
return failure();
32113220
if (subViewOp.getSourceType() == subViewOp.getType()) {
32123221
rewriter.replaceOp(subViewOp, subViewOp.getSource());

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
7070

7171
// -----
7272

73+
// CHECK-LABEL: func @subview_of_dynamic_full_size
74+
// CHECK-SAME: %[[ARG0:.+]]: memref<?xi8>
75+
// CHECK-SAME: %[[SIZE:.+]]: index
76+
// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape
77+
// CHECK-NOT: memref.subview
78+
// CHECK: return %[[EXPAND_SHAPE]] : memref<?x?xi8>
79+
func.func @subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index) -> memref<?x?xi8> {
80+
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
81+
%1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref<?x?xi8> to memref<?x?xi8>
82+
return %1 : memref<?x?xi8>
83+
}
84+
85+
// -----
86+
7387
// CHECK-LABEL: func @negative_subview_of_static_full_size
7488
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
7589
// CHECK-SAME: %[[IDX:.+]]: index

0 commit comments

Comments
 (0)