@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
31053105// / is the case if the all offsets are zero, all strides are 1, and the source
31063106// / shape is same as the size of the subview. In such cases, the subview can
31073107// / be folded into its source.
3108- static bool isTrivialSubViewOp (SubViewOp subViewOp) {
3108+ static bool isTrivialSubViewOp (OpBuilder &b, SubViewOp subViewOp) {
31093109 if (subViewOp.getSourceType ().getRank () != subViewOp.getType ().getRank ())
31103110 return false ;
31113111
@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
31273127 }))
31283128 return false ;
31293129
3130- // Check all size values are static and matches the (static) source shape.
3130+ // Check all size values match the source shape.
31313131 ArrayRef<int64_t > sourceShape = subViewOp.getSourceType ().getShape ();
3132- for (const auto &size : llvm::enumerate (mixedSizes)) {
3133- std::optional<int64_t > intValue = getConstantIntValue (size.value ());
3134- if (!intValue || *intValue != sourceShape[size.index ()])
3135- return false ;
3132+ if (llvm::all_of_zip (mixedSizes, sourceShape,
3133+ [](OpFoldResult mixedSize, int64_t staticSize) {
3134+ std::optional<int64_t > constSize =
3135+ getConstantIntValue (mixedSize);
3136+ return constSize.has_value () &&
3137+ *constSize == staticSize;
3138+ })) {
3139+ return true ;
31363140 }
3137- // All conditions met. The `SubViewOp` is foldable as a no-op.
3138- return true ;
3141+ auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource ());
3142+ if (!sourceOpResult)
3143+ return false ;
3144+ ReifiedRankedShapedTypeDims resultDims;
3145+ if (failed (reifyResultShapes (b, sourceOpResult.getOwner (), resultDims)))
3146+ return false ;
3147+ return llvm::equal (mixedSizes, resultDims[sourceOpResult.getResultNumber ()]);
31393148}
31403149
31413150namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
32063215
32073216 LogicalResult matchAndRewrite (SubViewOp subViewOp,
32083217 PatternRewriter &rewriter) const override {
3209- if (!isTrivialSubViewOp (subViewOp))
3218+ if (!isTrivialSubViewOp (rewriter, subViewOp))
32103219 return failure ();
32113220 if (subViewOp.getSourceType () == subViewOp.getType ()) {
32123221 rewriter.replaceOp (subViewOp, subViewOp.getSource ());
0 commit comments