1111//
1212// ===----------------------------------------------------------------------===//
1313
14+ #include " mlir/Dialect/Arith/Utils/Utils.h"
1415#include " mlir/Dialect/MemRef/Transforms/Passes.h"
1516
1617#include " mlir/Dialect/Affine/IR/AffineOps.h"
1718#include " mlir/Dialect/MemRef/IR/MemRef.h"
1819#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
1920#include " mlir/Dialect/Tensor/IR/Tensor.h"
21+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
22+ #include " mlir/IR/BuiltinTypeInterfaces.h"
23+ #include " mlir/IR/OpDefinition.h"
24+ #include " mlir/IR/TypeUtilities.h"
2025#include " mlir/Interfaces/DestinationStyleOpInterface.h"
2126#include " mlir/Interfaces/InferTypeOpInterface.h"
27+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
28+ #include " llvm/ADT/STLExtras.h"
29+ #include " llvm/Support/Casting.h"
2230#include " llvm/Support/InterleavedRange.h"
2331
2432#define DEBUG_TYPE " reify-result-shapes"
@@ -49,85 +57,15 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
4957 return op->emitWarning () << " failed to get the reified shapes" ;
5058 }
5159
52- bool modified = false ;
53- // Compute the new output types.
54- SmallVector<Type> outTypes;
55- for (const auto &[oldTy, reifiedShape] :
56- llvm::zip (op->getResultTypes (), reifiedResultShapes)) {
57- // Skip if it's not a memref or tensor type.
58- if (!isa<RankedTensorType, MemRefType>(oldTy)) {
59- outTypes.push_back (oldTy);
60- continue ;
61- }
62-
63- ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
64-
65- SmallVector<int64_t > shape = llvm::to_vector (shapedTy.getShape ());
66- for (auto &&[dim, ofr] : llvm::zip_equal (shape, reifiedShape)) {
67- std::optional<int64_t > maybeCst = getConstantIntValue (ofr);
68- // If the reified dim is dynamic set it appropriately.
69- if (!maybeCst.has_value ()) {
70- dim = ShapedType::kDynamic ;
71- continue ;
72- }
73- // Set the static dim.
74- dim = *maybeCst;
75- }
76-
77- // If the shape didn't change continue.
78- if (shape == shapedTy.getShape ()) {
79- outTypes.push_back (oldTy);
80- continue ;
81- }
82- modified = true ;
83- outTypes.push_back (shapedTy.cloneWith (shape, shapedTy.getElementType ()));
60+ for (auto [idx, reifiedShape] : llvm::enumerate (reifiedResultShapes)) {
61+ SmallVector<Value> vals =
62+ getValueOrCreateConstantIndexOp (rewriter, op->getLoc (), reifiedShape);
63+ vals.insert (vals.begin (), op->getResult (idx));
64+ OperationState state (op->getLoc (), " transform.materialize_shape" );
65+ state.addOperands (vals);
66+ rewriter.create (state);
8467 }
8568
86- // Return if we don't need to update.
87- if (!modified) {
88- LLVM_DEBUG ({ DBGS () << " - op doesn't require update\n " ; });
89- return success ();
90- }
91-
92- LLVM_DEBUG ({
93- DBGS () << " - oldTypes: " << llvm::interleaved_array (op->getResultTypes ())
94- << " \n " ;
95- DBGS () << " - outTypes: " << llvm::interleaved_array (outTypes) << " \n " ;
96- });
97-
98- // We now have outTypes that need to be turned to cast ops.
99- Location loc = op->getLoc ();
100- SmallVector<Value> newResults;
101- // TODO: `mlir::reifyResultShapes` and op verifiers may not agree atm.
102- // This is a confluence problem that will need to be addressed.
103- // For now, we know PadOp and ConcatOp are fine.
104- assert ((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation ())) &&
105- " incorrect op" );
106- Operation *newOp = rewriter.clone (*op);
107- for (auto [reifiedTy, oldRes] : llvm::zip (outTypes, op->getResults ())) {
108- OpResult newRes = newOp->getResult (oldRes.getResultNumber ());
109- Type oldTy = oldRes.getType ();
110- // Continue if the type remained invariant or is not shaped.
111- if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
112- newResults.push_back (newRes);
113- continue ;
114- }
115-
116- // Update the type.
117- newRes.setType (reifiedTy);
118- if (isa<RankedTensorType>(reifiedTy)) {
119- newResults.push_back (rewriter.create <tensor::CastOp>(loc, oldTy, newRes));
120- } else {
121- assert (isa<MemRefType>(reifiedTy) && " expected a memref type" );
122- newResults.push_back (rewriter.create <memref::CastOp>(loc, oldTy, newRes));
123- }
124- }
125-
126- LLVM_DEBUG ({
127- DBGS () << " - reified results " << llvm::interleaved_array (newResults)
128- << " \n " ;
129- });
130- rewriter.replaceOp (op, newResults);
13169 return success ();
13270}
13371
@@ -143,17 +81,80 @@ struct ReifyResultShapesPass final
14381} // namespace
14482
14583void ReifyResultShapesPass::runOnOperation () {
84+ // 1. Select ops that are not DPS and that do not carry a tied operand
85+ // shapes. For now, limit to tensor::PadOp and tensor::ConcatOp.
14686 SmallVector<ReifyRankedShapedTypeOpInterface> ops;
14787 getOperation ()->walk ([&](ReifyRankedShapedTypeOpInterface op) {
148- // Handle ops that are not DPS and that do not carry an tied operand shapes.
149- // For now, limit to tensor::PadOp and tensor::ConcatOp.
15088 if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation ()))
15189 return ;
15290 ops.push_back (op);
15391 });
92+
93+ // 2. Insert materialization points to tie the result tensor to its shape
94+ // components as SSA values.
15495 IRRewriter rewriter (&getContext ());
15596 for (ReifyRankedShapedTypeOpInterface op : ops) {
15697 rewriter.setInsertionPoint (op);
15798 (void )reifyOpResultShapes (rewriter, op);
15899 }
100+
101+ // 3. Resolve ranked shapes greedily for all other ops that implement
102+ // ReifyRankedShapedTypeOpInterface, achieving propagation of information.
103+ RewritePatternSet patterns (&getContext ());
104+ memref::populateResolveRankedShapedTypeResultDimsPatterns (patterns);
105+ memref::populateResolveShapedTypeResultDimsPatterns (patterns);
106+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
107+ return signalPassFailure ();
108+
109+ // 4. Process the information in the materialization points if more static
110+ // information is now available.
111+ getOperation ()->walk ([&](Operation *op) {
112+ if (op->getName ().getStringRef () != " transform.materialize_shape" )
113+ return ;
114+ auto resultShapedVal = cast<OpResult>(op->getOperands ().front ());
115+
116+ // 4.a. Fold information propagated to AffineApplyOp.
117+ SmallVector<OpFoldResult> ofrs =
118+ getAsOpFoldResult (op->getOperands ().drop_front ());
119+ for (auto &ofr : ofrs) {
120+ if (isa<Attribute>(ofr))
121+ continue ;
122+ if (auto affineApplyOp =
123+ (cast<Value>(ofr).getDefiningOp <affine::AffineApplyOp>())) {
124+ OpFoldResult o = affine::makeComposedFoldedAffineApply (
125+ rewriter, affineApplyOp->getLoc (), affineApplyOp.getAffineMap (),
126+ getAsOpFoldResult (affineApplyOp->getOperands ()),
127+ /* composeAffineMin=*/ true );
128+ if (isa<Attribute>(o))
129+ ofr = o;
130+ }
131+ }
132+
133+ // 4.b. Erase the materialization point.
134+ rewriter.eraseOp (op);
135+
136+ // 4.c. Clone the op and insert a better ShapeCastOp if the shape becomes
137+ // strictly more static.
138+ auto nst = cast<ShapedType>(resultShapedVal.getType ());
139+ nst = nst.cloneWith (getInducedShape (ofrs), getElementTypeOrSelf (nst));
140+ Operation *oldOp = resultShapedVal.getDefiningOp ();
141+ assert (llvm::isa_and_nonnull<ReifyRankedShapedTypeOpInterface>(oldOp));
142+ // 4.c.i. If the shape did not change, bail.
143+ auto onst = cast<ShapedType>(
144+ oldOp->getResultTypes ()[resultShapedVal.getResultNumber ()]);
145+ if (onst == nst)
146+ return ;
147+
148+ // 4.c.ii. If any shape dimension becomes less static, bail.
149+ for (auto [ns, os] : llvm::zip_equal (nst.getShape (), onst.getShape ())) {
150+ if (ShapedType::isDynamic (ns) && !ShapedType::isDynamic (os))
151+ return ;
152+ }
153+
154+ // 4.c.iii. RAUW
155+ Operation *newOp = rewriter.clone (*oldOp);
156+ OpResult newRes = newOp->getResult (resultShapedVal.getResultNumber ());
157+ newRes.setType (nst);
158+ rewriter.replaceAllUsesWith (resultShapedVal, newRes);
159+ });
159160}
0 commit comments