Skip to content

Commit b9b193f

Browse files
[mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result.
Current implementation of `reifyResultShapes` forces all implementations to return all dimensions of all results. This can be wasteful when you only require dimensions of one result, or a single dimension of a result. This was initially done with the restriction in mind that it might not always be possible to compute dimension of a single result or one dimension of a single result in all cases. To handle such cases - `reifyShapeOfResult` (which allows reifying dimensions of just one result), has a default implementation that calls `reifyResultShapes` and returns the dimensions of a single result. - `reifyDimOfResult` (which allows reifying a single dimension of a single result) has a default implementation that calls `reifyDimOfResult` and returns the value for the dimension of the result (which in turn for the default case would call `reifyDimOfResult`). Changes required downstream to integrate this change 1. In operation definitions in .td files, for those operations that implement the `ReifyRankedShapedTypeOpInterface`. ``` def <op-name> : Op<..., [..., DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface]]> ``` should be changed to ``` def <op-name> : Op<..., [..., DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface, [ "reifyResultShapes"]]]> ``` Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 69f9138 commit b9b193f

File tree

6 files changed

+73
-9
lines changed

6 files changed

+73
-9
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
2828

2929
def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
3030
[AttrSizedOperandSegments, BufferizableOpInterface,
31-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
31+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
32+
"reifyResultShapes"]>]> {
3233
let summary = "allocate buffer for a tensor";
3334

3435
let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
219220
: Bufferization_Op<"materialize_in_destination",
220221
[AllElementTypesMatch<["source", "dest"]>,
221222
BufferizableOpInterface, DestinationStyleOpInterface,
222-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
223+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
224+
"reifyResultShapes"]>,
223225
DeclareOpInterfaceMethods<SubsetOpInterface,
224226
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
225227
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
9494
def Linalg_SoftmaxOp : Linalg_Op<"softmax",
9595
[DestinationStyleOpInterface,
9696
PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
97-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
97+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
98+
["reifyResultShapes"]>,
9899
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
99100
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
100101
DeclareOpInterfaceMethods<TilingInterface,

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
3535
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
3636
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
3737
ConditionallySpeculatable, NoMemoryEffect,
38-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
38+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
39+
"reifyResultShapes"]>,
3940
TypesMatchWith<"result type matches type of dest",
4041
"dest", "result",
4142
"$_self">])> {

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1776,7 +1776,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17761776
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17771777
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
17781778
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1779-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
1779+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
1780+
["reifyResultShapes"]>]> {
17801781
let summary = "operation to produce a memref with a higher rank.";
17811782
let description = [{
17821783
The `memref.expand_shape` op produces a new view with a higher rank whose

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
361361
let methods = [
362362
InterfaceMethod<
363363
/*desc=*/[{
364-
Reify the shape of the result of an operation (typically in terms of the
365-
shape of its operands).
364+
Reify the shapes of all the result of an operation (typically in terms
365+
of the shape of its operands).
366366

367367
`reifiedReturnShapes` is populated with one vector per op result. Each
368368
of those vectors contains an OpFoldResult for each dimension of the
369369
shaped type. The given builder may be used to insert ops that compute
370370
result shapes.
371371

372-
If the shape of a particular result cannot be computed it must be empty.
372+
If the shape of a particular result cannot be computed it in terms of
373+
its operands it must be left empty. If any dimension of the result cannot
374+
be computed it must be set to OpFoldResult().
373375
}],
374376
/*retTy=*/"::llvm::LogicalResult",
375377
/*methodName=*/"reifyResultShapes",
376378
/*args=*/(ins "::mlir::OpBuilder &":$builder,
377-
"::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
379+
"::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
380+
/*methodBody=*/"",
381+
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
382+
>,
383+
InterfaceMethod<
384+
/*desc=*/[{
385+
Reify the shape of a single result of an operation (typically in terms
386+
of the shape of its operands).
387+
388+
Returns the shape of a single result of the operation as a
389+
`SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
390+
given builder may be used to insert ops that compute result shapes.
391+
392+
If any dimension of the result cannot be computed it must be set to
393+
OpFoldResult().
394+
}],
395+
/*retTy=*/"::llvm::FailureOr<SmallVector<::mlir::OpFoldResult>>",
396+
/*methodName=*/"reifyShapeOfResult",
397+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
398+
"int":$resultIndex),
399+
/*methodBody=*/"",
400+
/*defaultImplementation=*/[{
401+
ReifiedRankedShapedTypeDims reifiedShapes;
402+
if (failed($_op.reifyResultShapes(builder, reifiedShapes)))
403+
return failure();
404+
if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
405+
return $_op.emitOpError("invalid result index");
406+
return reifiedShapes[resultIndex];
407+
}]
408+
>,
409+
InterfaceMethod<
410+
/*desc=*/[{
411+
Reify the shape of a dimension of a given result of an operation
412+
(typically in terms of the shape of its operands).
413+
414+
Returns the shape of a specific dimension of a result of the operation as
415+
an OpFoldResult. The given builder may be used to insert ops that compute
416+
the shapes.
417+
418+
If the dimension of the result cannot be computed the method must return
419+
`failure()`.
420+
}],
421+
/*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
422+
/*methodName=*/"reifyDimOfResult",
423+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
424+
"int":$resultIndex, "int":$dim),
425+
/*methodBody=*/"",
426+
/*defaultImplementation=*/[{
427+
auto shapes = $_op.reifyShapeOfResult(builder, resultIndex);
428+
if (failed(shapes))
429+
return failure();
430+
if (dim < 0 || dim >= (int)((*shapes).size()))
431+
return $_op.emitOpError("invalid dimension");
432+
return (*shapes)[dim];
433+
}]
378434
>
379435
];
380436
}

mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ namespace {
7777
struct ReifyExpandShapeOp
7878
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
7979
ExpandShapeOp> {
80+
using Base =
81+
ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
82+
ExpandShapeOp>;
8083
LogicalResult
8184
reifyResultShapes(Operation *op, OpBuilder &b,
8285
ReifiedRankedShapedTypeDims &reifyResultShapes) const {

0 commit comments

Comments
 (0)