diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index fb2921fec9f79..792e722918306 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
-    auto loopLikeOp =
-        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
-    if (!loopLikeOp)
+    // TODO: Enable this for LoopLikeOpInterface. This is restricted to
+    // scf.forall because an scf.for iter_arg's shape might change in the
+    // loop body. For example:
+    // ```
+    // %0 = tensor.empty(%c1) : tensor<?xf32>
+    // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
+    //     tensor<?xf32> {
+    //   %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+    //   %2 = arith.addi %c1, %1 : index
+    //   %3 = tensor.empty(%2) : tensor<?xf32>
+    //   scf.yield %3 : tensor<?xf32>
+    // }
+    //
+    // ```
+    auto forAllOp =
+        dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forAllOp)
       return failure();
-    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
     rewriter.modifyOpInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
     return success();
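
For context, a minimal test-style sketch (not part of the patch; the function name and the `tensor<?xf32>` types are made up for illustration) of the kind of IR the restricted pattern still folds. Unlike the `scf.for` counterexample in the TODO above, an `scf.forall` `shared_outs` block argument always has the same shape as its tied init operand, so rewriting the `tensor.dim` source from `%arg0` to `%init` is safe. This rewrite is exercised by the resolve-shaped-type-result-dims pass that this file implements.

```mlir
// Hypothetical example: tensor.dim of an scf.forall shared_outs block
// argument can be resolved to tensor.dim of the tied init operand.
func.func @dim_of_forall_shared_out(%init: tensor<?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %r = scf.forall (%i) in (4) shared_outs(%arg0 = %init) -> (tensor<?xf32>) {
    // IterArgsToInitArgs may rewrite this to: tensor.dim %init, %c0
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    %slice = tensor.extract_slice %arg0[0] [%dim] [1]
        : tensor<?xf32> to tensor<?xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into %arg0[0] [%dim] [1]
          : tensor<?xf32> into tensor<?xf32>
    }
  }
  return %r : tensor<?xf32>
}
```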