@@ -103,6 +103,44 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
103103 return success ();
104104 }
105105};
106+
107+ // / Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
108+ // /
109+ // / ```
110+ // / %0 = ... : tensor<?x?xf32>
111+ // / scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
112+ // / %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
113+ // / ...
114+ // / }
115+ // / ```
116+ // /
117+ // / is folded to:
118+ // /
119+ // / ```
120+ // / %0 = ... : tensor<?x?xf32>
121+ // / scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
122+ // / %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
123+ // / ...
124+ // / }
125+ // / ```
126+ struct IterArgsToInitArgs : public OpRewritePattern <tensor::DimOp> {
127+ using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
128+
129+ LogicalResult matchAndRewrite (tensor::DimOp dimOp,
130+ PatternRewriter &rewriter) const final {
131+ auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource ());
132+ if (!blockArg)
133+ return failure ();
134+ auto loopLikeOp =
135+ dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock ()->getParentOp ());
136+ if (!loopLikeOp)
137+ return failure ();
138+ Value initArg = loopLikeOp.getTiedLoopInit (blockArg)->get ();
139+ rewriter.modifyOpInPlace (
140+ dimOp, [&]() { dimOp.getSourceMutable ().assign (initArg); });
141+ return success ();
142+ }
143+ };
106144} // namespace
107145
108146// ===----------------------------------------------------------------------===//
@@ -127,8 +165,8 @@ struct ResolveShapedTypeResultDimsPass final
127165void memref::populateResolveRankedShapedTypeResultDimsPatterns (
128166 RewritePatternSet &patterns) {
129167 patterns.add <DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
130- DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
131- patterns.getContext ());
168+ DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
169+ IterArgsToInitArgs>( patterns.getContext ());
132170}
133171
134172void memref::populateResolveShapedTypeResultDimsPatterns (
0 commit comments