Skip to content

Commit c7c745a

Browse files
matthias-springer authored and memfrob committed
[mlir][tensor] Add tensor.dim operation
* Split memref.dim into two operations: memref.dim and tensor.dim. Both ops have the same builder interface and op argument names, so that they can be used with templates in patterns that apply to both tensors and memrefs (e.g., some patterns in Linalg).
* Add constant materializer to TensorDialect (needed for folding in affine.apply etc.).
* Remove some MemRefDialect dependencies, make some explicit.

Differential Revision: https://reviews.llvm.org/D105165
1 parent 7c2669e commit c7c745a

File tree

68 files changed

+670
-499
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+670
-499
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
365365
"dialect";
366366
let constructor = "mlir::createConvertShapeToStandardPass()";
367367
let dependentDialects = [
368-
"memref::MemRefDialect",
369368
"StandardOpsDialect",
370369
"scf::SCFDialect",
371370
"tensor::TensorDialect"

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3434
/// Given an operation, retrieves the value of each dynamic dimension through
3535
/// constructing the necessary DimOp operators.
3636
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
37+
38+
// Helper function that creates a memref::DimOp or tensor::DimOp depending on
39+
// the type of `source`.
40+
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
3741
} // namespace mlir
3842

3943
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
512512
// DimOp
513513
//===----------------------------------------------------------------------===//
514514

515-
def DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
515+
def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
516516
let summary = "dimension index operation";
517517
let description = [{
518518
The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -538,18 +538,17 @@ def DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
538538
```
539539
}];
540540

541-
let arguments = (ins AnyTypeOf<[AnyTensor, AnyRankedOrUnrankedMemRef],
542-
"any memref or tensor type">:$memrefOrTensor,
541+
let arguments = (ins AnyRankedOrUnrankedMemRef:$source,
543542
Index:$index);
544543
let results = (outs Index:$result);
545544

546545
let assemblyFormat = [{
547-
attr-dict $memrefOrTensor `,` $index `:` type($memrefOrTensor)
546+
attr-dict $source `,` $index `:` type($source)
548547
}];
549548

550549
let builders = [
551-
OpBuilder<(ins "Value":$memrefOrTensor, "int64_t":$index)>,
552-
OpBuilder<(ins "Value":$memrefOrTensor, "Value":$index)>
550+
OpBuilder<(ins "Value":$source, "int64_t":$index)>,
551+
OpBuilder<(ins "Value":$source, "Value":$index)>
553552
];
554553

555554
let extraClassDeclaration = [{
@@ -1288,6 +1287,7 @@ def TensorLoadOp : MemRef_Op<"tensor_load",
12881287

12891288
let assemblyFormat = "$memref attr-dict `:` type($memref)";
12901289

1290+
let hasCanonicalizer = 1;
12911291
let hasFolder = 1;
12921292
}
12931293

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ include "mlir/Pass/PassBase.td"
1414
def StdBufferize : FunctionPass<"std-bufferize"> {
1515
let summary = "Bufferize the std dialect";
1616
let constructor = "mlir::createStdBufferizePass()";
17-
let dependentDialects = ["scf::SCFDialect"];
17+
let dependentDialects = ["memref::MemRefDialect", "scf::SCFDialect"];
1818
}
1919

2020
def StdExpandOps : FunctionPass<"std-expand"> {

mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/IR/OpBase.td"
1414
def Tensor_Dialect : Dialect {
1515
let name = "tensor";
1616
let cppNamespace = "::mlir::tensor";
17+
1718
let description = [{
1819
The `tensor` dialect is intended to hold core tensor creation and
1920
manipulation ops, which are not strongly associated with any particular
@@ -43,6 +44,8 @@ def Tensor_Dialect : Dialect {
4344
dialect), and does not live in this dialect.
4445

4546
}];
47+
48+
let hasConstantMaterializer = 1;
4649
}
4750

4851
#endif // TENSOR_BASE

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,58 @@ def Tensor_CastOp : Tensor_Op<"cast", [
6060
let verifier = ?;
6161
}
6262

63+
//===----------------------------------------------------------------------===//
64+
// DimOp
65+
//===----------------------------------------------------------------------===//
66+
67+
def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
68+
let summary = "dimension index operation";
69+
let description = [{
70+
The `dim` operation takes a tensor and a dimension operand of type `index`.
71+
It returns the size of the requested dimension of the given tensor.
72+
If the dimension index is out of bounds, the behavior is undefined.
73+
74+
The specified tensor type is that of the first operand.
75+
76+
Example:
77+
78+
```mlir
79+
// Always returns 4, can be constant folded:
80+
%c0 = constant 0 : index
81+
%x = tensor.dim %A, %c0 : tensor<4x?xf32>
82+
83+
// Returns the dynamic dimension of %A.
84+
%c1 = constant 1 : index
85+
%y = tensor.dim %A, %c1 : tensor<4x?xf32>
86+
87+
// Equivalent generic form:
88+
%x = "tensor.dim"(%A, %c0) : (tensor<4x?xf32>, index) -> index
89+
%y = "tensor.dim"(%A, %c1) : (tensor<4x?xf32>, index) -> index
90+
```
91+
}];
92+
93+
let arguments = (ins AnyTensor:$source,
94+
Index:$index);
95+
let results = (outs Index:$result);
96+
97+
let assemblyFormat = [{
98+
attr-dict $source `,` $index `:` type($source)
99+
}];
100+
101+
let builders = [
102+
OpBuilder<(ins "Value":$source, "int64_t":$index)>,
103+
OpBuilder<(ins "Value":$source, "Value":$index)>
104+
];
105+
106+
let extraClassDeclaration = [{
107+
/// Helper function to get the index as a simple integer if it is constant.
108+
Optional<int64_t> getConstantIndex();
109+
}];
110+
111+
let hasCanonicalizer = 1;
112+
let hasFolder = 1;
113+
}
114+
63115
//===----------------------------------------------------------------------===//
64116
// ExtractOp
65117
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ include "mlir/Pass/PassBase.td"
1414
def TensorBufferize : FunctionPass<"tensor-bufferize"> {
1515
let summary = "Bufferize the `tensor` dialect";
1616
let constructor = "mlir::createTensorBufferizePass()";
17-
let dependentDialects = ["scf::SCFDialect"];
17+
let dependentDialects = ["scf::SCFDialect", "memref::MemRefDialect"];
1818
}
1919

2020
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ add_mlir_conversion_library(MLIRShapeToStandard
1818

1919
LINK_LIBS PUBLIC
2020
MLIRIR
21-
MLIRMemRef
2221
MLIRShape
2322
MLIRTensor
2423
MLIRPass

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
1010

1111
#include "../PassDetail.h"
12-
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1312
#include "mlir/Dialect/SCF/SCF.h"
1413
#include "mlir/Dialect/Shape/IR/Shape.h"
1514
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -140,7 +139,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
140139
// dimension in the tensor.
141140
SmallVector<Value> ranks, rankDiffs;
142141
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
143-
return lb.create<memref::DimOp>(v, zero);
142+
return lb.create<tensor::DimOp>(v, zero);
144143
}));
145144

146145
// Find the maximum rank
@@ -254,7 +253,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
254253
// dimension in the tensor.
255254
SmallVector<Value> ranks, rankDiffs;
256255
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
257-
return lb.create<memref::DimOp>(v, zero);
256+
return lb.create<tensor::DimOp>(v, zero);
258257
}));
259258

260259
// Find the maximum rank
@@ -346,7 +345,7 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
346345
// circumvents the necessity to materialize the shape in memory.
347346
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
348347
if (shapeOfOp.arg().getType().isa<ShapedType>()) {
349-
rewriter.replaceOpWithNewOp<memref::DimOp>(op, shapeOfOp.arg(),
348+
rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(),
350349
transformed.dim());
351350
return success();
352351
}
@@ -377,7 +376,7 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
377376
return failure();
378377

379378
shape::RankOp::Adaptor transformed(operands);
380-
rewriter.replaceOpWithNewOp<memref::DimOp>(op, transformed.shape(), 0);
379+
rewriter.replaceOpWithNewOp<tensor::DimOp>(op, transformed.shape(), 0);
381380
return success();
382381
}
383382

@@ -407,7 +406,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
407406
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
408407
Type indexTy = rewriter.getIndexType();
409408
Value rank =
410-
rewriter.create<memref::DimOp>(loc, indexTy, transformed.shape(), zero);
409+
rewriter.create<tensor::DimOp>(loc, indexTy, transformed.shape(), zero);
411410

412411
auto loop = rewriter.create<scf::ForOp>(
413412
loc, zero, rank, one, op.initVals(),
@@ -494,11 +493,11 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
494493
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
495494
Value firstShape = transformed.shapes().front();
496495
Value firstRank =
497-
rewriter.create<memref::DimOp>(loc, indexTy, firstShape, zero);
496+
rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
498497
Value result = nullptr;
499498
// Generate a linear sequence of compares, all with firstShape as lhs.
500499
for (Value shape : transformed.shapes().drop_front(1)) {
501-
Value rank = rewriter.create<memref::DimOp>(loc, indexTy, shape, zero);
500+
Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
502501
Value eqRank =
503502
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
504503
auto same = rewriter.create<IfOp>(
@@ -563,7 +562,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
563562
int64_t rank = rankedTensorTy.getRank();
564563
for (int64_t i = 0; i < rank; i++) {
565564
if (rankedTensorTy.isDynamicDim(i)) {
566-
Value extent = rewriter.create<memref::DimOp>(loc, tensor, i);
565+
Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
567566
extentValues.push_back(extent);
568567
} else {
569568
Value extent =
@@ -587,7 +586,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
587586
op, getExtentTensorType(ctx), ValueRange{rank},
588587
[&](OpBuilder &b, Location loc, ValueRange args) {
589588
Value dim = args.front();
590-
Value extent = b.create<memref::DimOp>(loc, tensor, dim);
589+
Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
591590
b.create<tensor::YieldOp>(loc, extent);
592591
});
593592

@@ -617,7 +616,7 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
617616
SplitAtOp::Adaptor transformed(op);
618617
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
619618
Value zero = b.create<ConstantIndexOp>(0);
620-
Value rank = b.create<memref::DimOp>(transformed.operand(), zero);
619+
Value rank = b.create<tensor::DimOp>(transformed.operand(), zero);
621620

622621
// index < 0 ? index + rank : index
623622
Value originalIndex = transformed.index();
@@ -675,8 +674,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
675674
// Setup target legality.
676675
MLIRContext &ctx = getContext();
677676
ConversionTarget target(ctx);
678-
target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
679-
tensor::TensorDialect>();
677+
target
678+
.addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
680679
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
681680

682681
// Setup conversion patterns.

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2965,7 +2965,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
29652965
LogicalResult
29662966
matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
29672967
ConversionPatternRewriter &rewriter) const override {
2968-
Type operandType = dimOp.memrefOrTensor().getType();
2968+
Type operandType = dimOp.source().getType();
29692969
if (operandType.isa<UnrankedMemRefType>()) {
29702970
rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
29712971
operandType, dimOp, operands, rewriter)});
@@ -2977,7 +2977,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
29772977
operandType, dimOp, operands, rewriter)});
29782978
return success();
29792979
}
2980-
return failure();
2980+
llvm_unreachable("expected MemRefType or UnrankedMemRefType");
29812981
}
29822982

29832983
private:
@@ -2995,7 +2995,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
29952995
// Extract pointer to the underlying ranked descriptor and bitcast it to a
29962996
// memref<element_type> descriptor pointer to minimize the number of GEP
29972997
// operations.
2998-
UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
2998+
UnrankedMemRefDescriptor unrankedDesc(transformed.source());
29992999
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
30003000
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
30013001
loc,
@@ -3033,7 +3033,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
30333033
int64_t i = index.getValue();
30343034
if (memRefType.isDynamicDim(i)) {
30353035
// extract dynamic size from the memref descriptor.
3036-
MemRefDescriptor descriptor(transformed.memrefOrTensor());
3036+
MemRefDescriptor descriptor(transformed.source());
30373037
return descriptor.size(rewriter, loc, i);
30383038
}
30393039
// Use constant for static size.
@@ -3042,7 +3042,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
30423042
}
30433043
Value index = dimOp.index();
30443044
int64_t rank = memRefType.getRank();
3045-
MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
3045+
MemRefDescriptor memrefDescriptor(transformed.source());
30463046
return memrefDescriptor.size(rewriter, loc, index, rank);
30473047
}
30483048
};

0 commit comments

Comments
 (0)