Skip to content

Commit af93380

Browse files
[WIP] More powerful form of ReifyResultShapes
1 parent 00ee6bd commit af93380

File tree

3 files changed

+99
-79
lines changed

3 files changed

+99
-79
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
110110
/// If all ofrs are constant integers or IntegerAttrs, return the integers.
111111
std::optional<SmallVector<int64_t>>
112112
getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
113+
/// Return a shape induced by ofrs, with ShapedType::kDynamic encoding dynamic
114+
/// values.
115+
SmallVector<int64_t> getInducedShape(ArrayRef<OpFoldResult> ofrs);
113116

114117
/// Return true if `ofr` is constant integer equal to `value`.
115118
bool isConstantIntValue(OpFoldResult ofr, int64_t value);

mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp

Lines changed: 80 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1415
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
1516

1617
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1718
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1819
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1920
#include "mlir/Dialect/Tensor/IR/Tensor.h"
21+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
22+
#include "mlir/IR/BuiltinTypeInterfaces.h"
23+
#include "mlir/IR/OpDefinition.h"
24+
#include "mlir/IR/TypeUtilities.h"
2025
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2126
#include "mlir/Interfaces/InferTypeOpInterface.h"
27+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28+
#include "llvm/ADT/STLExtras.h"
29+
#include "llvm/Support/Casting.h"
2230
#include "llvm/Support/InterleavedRange.h"
2331

2432
#define DEBUG_TYPE "reify-result-shapes"
@@ -49,85 +57,15 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
4957
return op->emitWarning() << "failed to get the reified shapes";
5058
}
5159

52-
bool modified = false;
53-
// Compute the new output types.
54-
SmallVector<Type> outTypes;
55-
for (const auto &[oldTy, reifiedShape] :
56-
llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
57-
// Skip if it's not a memref or tensor type.
58-
if (!isa<RankedTensorType, MemRefType>(oldTy)) {
59-
outTypes.push_back(oldTy);
60-
continue;
61-
}
62-
63-
ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
64-
65-
SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
66-
for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
67-
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
68-
// If the reified dim is dynamic set it appropriately.
69-
if (!maybeCst.has_value()) {
70-
dim = ShapedType::kDynamic;
71-
continue;
72-
}
73-
// Set the static dim.
74-
dim = *maybeCst;
75-
}
76-
77-
// If the shape didn't change continue.
78-
if (shape == shapedTy.getShape()) {
79-
outTypes.push_back(oldTy);
80-
continue;
81-
}
82-
modified = true;
83-
outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
60+
for (auto [idx, reifiedShape] : llvm::enumerate(reifiedResultShapes)) {
61+
SmallVector<Value> vals =
62+
getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), reifiedShape);
63+
vals.insert(vals.begin(), op->getResult(idx));
64+
OperationState state(op->getLoc(), "transform.materialize_shape");
65+
state.addOperands(vals);
66+
rewriter.create(state);
8467
}
8568

86-
// Return if we don't need to update.
87-
if (!modified) {
88-
LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
89-
return success();
90-
}
91-
92-
LLVM_DEBUG({
93-
DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
94-
<< " \n";
95-
DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
96-
});
97-
98-
// We now have outTypes that need to be turned to cast ops.
99-
Location loc = op->getLoc();
100-
SmallVector<Value> newResults;
101-
// TODO: `mlir::reifyResultShapes` and op verifiers may not agree atm.
102-
// This is a confluence problem that will need to be addressed.
103-
// For now, we know PadOp and ConcatOp are fine.
104-
assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
105-
"incorrect op");
106-
Operation *newOp = rewriter.clone(*op);
107-
for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
108-
OpResult newRes = newOp->getResult(oldRes.getResultNumber());
109-
Type oldTy = oldRes.getType();
110-
// Continue if the type remained invariant or is not shaped.
111-
if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
112-
newResults.push_back(newRes);
113-
continue;
114-
}
115-
116-
// Update the type.
117-
newRes.setType(reifiedTy);
118-
if (isa<RankedTensorType>(reifiedTy)) {
119-
newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
120-
} else {
121-
assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
122-
newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
123-
}
124-
}
125-
126-
LLVM_DEBUG({
127-
DBGS() << "- reified results " << llvm::interleaved_array(newResults)
128-
<< "\n";
129-
});
130-
rewriter.replaceOp(op, newResults);
13169
return success();
13270
}
13371

@@ -143,17 +81,80 @@ struct ReifyResultShapesPass final
14381
} // namespace
14482

14583
void ReifyResultShapesPass::runOnOperation() {
84+
// 1. Select ops that are not DPS and that do not carry tied operand
85+
// shapes. For now, limit to tensor::PadOp and tensor::ConcatOp.
14686
SmallVector<ReifyRankedShapedTypeOpInterface> ops;
14787
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
148-
// Handle ops that are not DPS and that do not carry an tied operand shapes.
149-
// For now, limit to tensor::PadOp and tensor::ConcatOp.
15088
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
15189
return;
15290
ops.push_back(op);
15391
});
92+
93+
// 2. Insert materialization points to tie the result tensor to its shape
94+
// components as SSA values.
15495
IRRewriter rewriter(&getContext());
15596
for (ReifyRankedShapedTypeOpInterface op : ops) {
15697
rewriter.setInsertionPoint(op);
15798
(void)reifyOpResultShapes(rewriter, op);
15899
}
100+
101+
// 3. Resolve ranked shapes greedily for all other ops that implement
102+
// ReifyRankedShapedTypeOpInterface, achieving propagation of information.
103+
RewritePatternSet patterns(&getContext());
104+
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
105+
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
106+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
107+
return signalPassFailure();
108+
109+
// 4. Process the information in the materialization points if more static
110+
// information is now available.
111+
getOperation()->walk([&](Operation *op) {
112+
if (op->getName().getStringRef() != "transform.materialize_shape")
113+
return;
114+
auto resultShapedVal = cast<OpResult>(op->getOperands().front());
115+
116+
// 4.a. Fold information propagated to AffineApplyOp.
117+
SmallVector<OpFoldResult> ofrs =
118+
getAsOpFoldResult(op->getOperands().drop_front());
119+
for (auto &ofr : ofrs) {
120+
if (isa<Attribute>(ofr))
121+
continue;
122+
if (auto affineApplyOp =
123+
(cast<Value>(ofr).getDefiningOp<affine::AffineApplyOp>())) {
124+
OpFoldResult o = affine::makeComposedFoldedAffineApply(
125+
rewriter, affineApplyOp->getLoc(), affineApplyOp.getAffineMap(),
126+
getAsOpFoldResult(affineApplyOp->getOperands()),
127+
/*composeAffineMin=*/true);
128+
if (isa<Attribute>(o))
129+
ofr = o;
130+
}
131+
}
132+
133+
// 4.b. Erase the materialization point.
134+
rewriter.eraseOp(op);
135+
136+
// 4.c. Clone the op and insert a better ShapeCastOp if the shape becomes
137+
// strictly more static.
138+
auto nst = cast<ShapedType>(resultShapedVal.getType());
139+
nst = nst.cloneWith(getInducedShape(ofrs), getElementTypeOrSelf(nst));
140+
Operation *oldOp = resultShapedVal.getDefiningOp();
141+
assert(llvm::isa_and_nonnull<ReifyRankedShapedTypeOpInterface>(oldOp));
142+
// 4.c.i. If the shape did not change, bail.
143+
auto onst = cast<ShapedType>(
144+
oldOp->getResultTypes()[resultShapedVal.getResultNumber()]);
145+
if (onst == nst)
146+
return;
147+
148+
// 4.c.ii. If any shape dimension becomes less static, bail.
149+
for (auto [ns, os] : llvm::zip_equal(nst.getShape(), onst.getShape())) {
150+
if (ShapedType::isDynamic(ns) && !ShapedType::isDynamic(os))
151+
return;
152+
}
153+
154+
// 4.c.iii. RAUW
155+
Operation *newOp = rewriter.clone(*oldOp);
156+
OpResult newRes = newOp->getResult(resultShapedVal.getResultNumber());
157+
newRes.setType(nst);
158+
rewriter.replaceAllUsesWith(resultShapedVal, newRes);
159+
});
159160
}

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
141141
return res;
142142
}
143143

144+
SmallVector<int64_t> getInducedShape(ArrayRef<OpFoldResult> ofrs) {
145+
SmallVector<int64_t> shape;
146+
shape.resize_for_overwrite(ofrs.size());
147+
for (auto &&[dim, ofr] : llvm::zip_equal(shape, ofrs)) {
148+
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
149+
// If the reified dim is dynamic, set it appropriately.
150+
if (!maybeCst.has_value()) {
151+
dim = ShapedType::kDynamic;
152+
continue;
153+
}
154+
// Set the static dim.
155+
dim = *maybeCst;
156+
}
157+
return shape;
158+
}
159+
144160
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
145161
return getConstantIntValue(ofr) == value;
146162
}

0 commit comments

Comments
 (0)