Skip to content

Commit fc093f1

Browse files
[mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result. (#162924)
Current implementation of `reifyResultShapes` forces all implementations to return all dimensions of all results. This can be wasteful when you only require dimensions of one result, or a single dimension of a result. Further this also creates issues with using patterns to resolve the `tensor.dim` and `memref.dim` operations since the extra operations created result in the pattern rewriter entering an infinite loop (eventually breaking out of the loop due to the iteration limit on the pattern rewriter). This is demonstrated by some of the test cases added here that hit this limit when using `--resolve-shaped-type-result-dims` and `--resolve-ranked-shaped-type-result-dims`. To resolve this issue the interface should allow for creating just the operations needed. This change is the first step in resolving this. The original implementation was done with the restriction in mind that it might not always be possible to compute dimension of a single result or one dimension of a single result in all cases. To account for such cases, two additional interface methods are added - `reifyShapeOfResult` (which allows reifying dimensions of just one result), has a default implementation that calls `reifyResultShapes` and returns the dimensions of a single result. - `reifyDimOfResult` (which allows reifying a single dimension of a single result) has a default implementation that calls `reifyDimOfResult` and returns the value for the dimension of the result (which in turn for the default case would call `reifyDimOfResult`). While this change sets up the interface, ideally most operations will implement the `refiyDimOfResult` when possible. For almost all operations in tree this is true. Subsequent commits will change those incrementally. Some of the tests added here that check that the default implementations for the above method work as expected, also end up hitting the pattern rewriter limit when using `--resolve-ranked-shaped-type-result-dims`/ `--resolve-ranked-shaped-type-result-dims`. For testing purposes, a flag is added to these passes that ignore the error returned by the pattern application (this flag is left on by default to maintain current state). Changes required downstream to integrate this change 1. In operation definitions in .td files, for those operations that implement the `ReifyRankedShapedTypeOpInterface`. ``` def <op-name> : Op<..., [..., DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface]]> ``` should be changed to ``` def <op-name> : Op<..., [..., DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface, [ "reifyResultShapes"]]]> ``` --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent fa98fcd commit fc093f1

File tree

16 files changed

+459
-38
lines changed

16 files changed

+459
-38
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
2828

2929
def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
3030
[AttrSizedOperandSegments, BufferizableOpInterface,
31-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
31+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
32+
"reifyResultShapes"]>]> {
3233
let summary = "allocate buffer for a tensor";
3334

3435
let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
219220
: Bufferization_Op<"materialize_in_destination",
220221
[AllElementTypesMatch<["source", "dest"]>,
221222
BufferizableOpInterface, DestinationStyleOpInterface,
222-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
223+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
224+
"reifyResultShapes"]>,
223225
DeclareOpInterfaceMethods<SubsetOpInterface,
224226
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
225227
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
9494
def Linalg_SoftmaxOp : Linalg_Op<"softmax",
9595
[DestinationStyleOpInterface,
9696
PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
97-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
97+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
98+
["reifyResultShapes"]>,
9899
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
99100
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
100101
DeclareOpInterfaceMethods<TilingInterface,

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
3535
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
3636
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
3737
ConditionallySpeculatable, NoMemoryEffect,
38-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
38+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
39+
"reifyResultShapes"]>,
3940
TypesMatchWith<"result type matches type of dest",
4041
"dest", "result",
4142
"$_self">])> {

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1783,7 +1783,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17831783
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17841784
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
17851785
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1786-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
1786+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
1787+
["reifyResultShapes"]>]> {
17871788
let summary = "operation to produce a memref with a higher rank.";
17881789
let description = [{
17891790
The `memref.expand_shape` op produces a new view with a higher rank whose

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
164164
implement the `ReifyRankedShapedTypeOpInterface` in terms of
165165
shapes of its operands.
166166
}];
167+
let options = [
168+
Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
169+
/*default=*/"true",
170+
"Throw an error when pattern rewriter hits iteration limit">,
171+
];
167172
let dependentDialects = [
168173
"memref::MemRefDialect", "tensor::TensorDialect"
169174
];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
177182
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
178183
operands.
179184
}];
185+
let options = [
186+
Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
187+
/*default=*/"true",
188+
"Throw an error when pattern rewriter hits iteration limit">,
189+
];
180190
let dependentDialects = [
181191
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
182192
];

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
131131
def Tensor_ConcatOp : Tensor_Op<"concat",
132132
[Pure,
133133
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
134-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
134+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
135+
"reifyResultShapes"]>,
136+
]> {
135137
let summary = "tensor concatenation operation";
136138
let description = [{
137139
The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
261263

262264
def Tensor_EmptyOp : Tensor_Op<"empty",
263265
[Pure,
264-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
266+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
267+
"reifyResultShapes"]>]> {
265268
let summary = "empty tensor operation";
266269

267270
let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
358361

359362
def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
360363
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
361-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
364+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
365+
"reifyResultShapes"]>,
362366
AttrSizedOperandSegments,
363367
Pure,
364368
OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
740744
def Tensor_GenerateOp : Tensor_Op<"generate", [
741745
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
742746
RecursiveMemoryEffects,
743-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
747+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
748+
"reifyResultShapes"]>,
744749
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
745750
let summary = "Creates a dynamically sized tensor from elements";
746751
let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
835840

836841
def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
837842
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
838-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
843+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
844+
"reifyResultShapes"]>,
839845
AttrSizedOperandSegments,
840846
DestinationStyleOpInterface,
841847
Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12561262

12571263
def Tensor_PadOp : Tensor_Op<"pad", [
12581264
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1259-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1265+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
1266+
"reifyResultShapes"]>,
12601267
AttrSizedOperandSegments,
12611268
Pure,
12621269
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
17641771

17651772
def Tensor_SplatOp : Tensor_Op<"splat", [
17661773
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1767-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1774+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
1775+
["reifyResultShapes"]>,
17681776
Pure,
17691777
TypesMatchWith<"operand type matches element type of result",
17701778
"aggregate", "input",

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2219,7 +2219,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
22192219
// Operator: transpose
22202220
//===----------------------------------------------------------------------===//
22212221
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
2222-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
2222+
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
2223+
["reifyResultShapes"]>,
22232224
AllElementTypesMatch<["input1", "output"]>]> {
22242225
let summary = "Transpose operator.";
22252226

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
3333
LogicalResult
3434
reifyResultShapes(OpBuilder &b, Operation *op,
3535
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
36+
FailureOr<SmallVector<OpFoldResult>>
37+
reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
38+
FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
39+
int resultIndex, int dim);
3640

3741
/// Adaptor class to abstract the differences between whether value is from
3842
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
361361
let methods = [
362362
InterfaceMethod<
363363
/*desc=*/[{
364-
Reify the shape of the result of an operation (typically in terms of the
365-
shape of its operands).
364+
Reify the shapes of all the result of an operation (typically in terms
365+
of the shape of its operands).
366366

367367
`reifiedReturnShapes` is populated with one vector per op result. Each
368368
of those vectors contains an OpFoldResult for each dimension of the
369369
shaped type. The given builder may be used to insert ops that compute
370370
result shapes.
371371

372-
If the shape of a particular result cannot be computed it must be empty.
372+
If the shape of a particular result cannot be computed it in terms of
373+
its operands it must be left empty. If any dimension of the result cannot
374+
be computed it must be set to OpFoldResult().
373375
}],
374376
/*retTy=*/"::llvm::LogicalResult",
375377
/*methodName=*/"reifyResultShapes",
376378
/*args=*/(ins "::mlir::OpBuilder &":$builder,
377-
"::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
379+
"::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
380+
/*methodBody=*/"",
381+
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
382+
>,
383+
InterfaceMethod<
384+
/*desc=*/[{
385+
Reify the shape of a single result of an operation (typically in terms
386+
of the shape of its operands).
387+
388+
Returns the shape of a single result of the operation as a
389+
`SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
390+
given builder may be used to insert ops that compute result shapes.
391+
392+
If any dimension of the result cannot be computed it must be set to
393+
OpFoldResult().
394+
}],
395+
/*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
396+
/*methodName=*/"reifyShapeOfResult",
397+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
398+
"int":$resultIndex),
399+
/*methodBody=*/"",
400+
/*defaultImplementation=*/[{
401+
ReifiedRankedShapedTypeDims reifiedShapes;
402+
if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
403+
return failure();
404+
if (resultIndex < 0 || resultIndex >= static_cast<int>(reifiedShapes.size()))
405+
return $_op.emitOpError("invalid result index");
406+
return reifiedShapes[resultIndex];
407+
}]
408+
>,
409+
InterfaceMethod<
410+
/*desc=*/[{
411+
Reify the shape of a dimension of a given result of an operation
412+
(typically in terms of the shape of its operands).
413+
414+
Returns the shape of a specific dimension of a result of the operation as
415+
an OpFoldResult. The given builder may be used to insert ops that compute
416+
the shapes.
417+
418+
If the dimension of the result cannot be computed the method must return
419+
`failure()`.
420+
}],
421+
/*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
422+
/*methodName=*/"reifyDimOfResult",
423+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
424+
"int":$resultIndex, "int":$dim),
425+
/*methodBody=*/"",
426+
/*defaultImplementation=*/[{
427+
auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
428+
if (failed(shapes))
429+
return failure();
430+
if (dim < 0 || dim >= static_cast<int>((*shapes).size()))
431+
return $_op.emitOpError("invalid dimension");
432+
return (*shapes)[dim];
433+
}]
378434
>
379435
];
380436
}

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
9090
if (!dimIndex)
9191
return failure();
9292

93-
ReifiedRankedShapedTypeDims reifiedResultShapes;
94-
if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
95-
reifiedResultShapes)))
93+
FailureOr<OpFoldResult> replacement = reifyDimOfResult(
94+
rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
95+
if (failed(replacement))
9696
return failure();
97-
unsigned resultNumber = dimValue.getResultNumber();
98-
// Do not apply pattern if the IR is invalid (dim out of bounds).
99-
if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
100-
return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
101-
Value replacement = getValueOrCreateConstantIndexOp(
102-
rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
103-
rewriter.replaceOp(dimOp, replacement);
97+
// Check if the OpFoldResult is empty (unreifiable dimension).
98+
if (!replacement.value())
99+
return failure();
100+
Value replacementVal = getValueOrCreateConstantIndexOp(
101+
rewriter, dimOp.getLoc(), replacement.value());
102+
rewriter.replaceOp(dimOp, replacementVal);
104103
return success();
105104
}
106105
};
@@ -166,12 +165,14 @@ namespace {
166165
struct ResolveRankedShapeTypeResultDimsPass final
167166
: public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
168167
ResolveRankedShapeTypeResultDimsPass> {
168+
using Base::Base;
169169
void runOnOperation() override;
170170
};
171171

172172
struct ResolveShapedTypeResultDimsPass final
173173
: public memref::impl::ResolveShapedTypeResultDimsPassBase<
174174
ResolveShapedTypeResultDimsPass> {
175+
using Base::Base;
175176
void runOnOperation() override;
176177
};
177178

@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
195196
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
196197
RewritePatternSet patterns(&getContext());
197198
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
198-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
199+
auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
200+
if (errorOnPatternIterationLimit && failed(result)) {
201+
getOperation()->emitOpError(
202+
"dim operation resolution hit pattern iteration limit");
199203
return signalPassFailure();
204+
}
200205
}
201206

202207
void ResolveShapedTypeResultDimsPass::runOnOperation() {
203208
RewritePatternSet patterns(&getContext());
204209
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
205210
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
206-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
211+
auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
212+
if (errorOnPatternIterationLimit && failed(result)) {
213+
getOperation()->emitOpError(
214+
"dim operation resolution hit pattern iteration limit");
207215
return signalPassFailure();
216+
}
208217
}

0 commit comments

Comments
 (0)