Skip to content

Commit ba624b7

Browse files
committed
[mlir][Tensor] Use output_shape for ExpandShapeOp type inference
1 parent: 017c75b · commit: ba624b7

File tree

3 files changed

+27
-107
lines changed

3 files changed

+27
-107
lines changed

mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp

Lines changed: 23 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,6 @@
1616
using namespace mlir;
1717
using namespace mlir::tensor;
1818

19-
/// Compute a map that for a given dimension of the expanded type gives the
20-
/// dimension in the collapsed type it maps to. Essentially it's the inverse of
21-
/// the `reassociation` maps.
22-
static llvm::DenseMap<int64_t, int64_t>
23-
getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
24-
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
25-
for (const auto &map : enumerate(reassociation)) {
26-
unsigned startPos =
27-
cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
28-
unsigned endPos =
29-
cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
30-
for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
31-
expandedDimToCollapsedDim[dim] = map.index();
32-
}
33-
}
34-
return expandedDimToCollapsedDim;
35-
}
36-
3719
/// For reshape op compute the shape at dimension `dimIndex` of the output in
3820
/// terms of shape of the `src`, when the reshape op is a collapsing
3921
/// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
7658
}));
7759
}
7860

79-
/// For an expanding reshape op, compute the value for a dimension of the output
80-
/// from the shape of the input.
81-
static OpFoldResult getExpandedOutputDimFromInputShape(
82-
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
83-
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
84-
llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
85-
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
86-
// Static dimension: return Attribute.
87-
return builder.getIndexAttr(dstStaticShape[dimIndex]);
88-
}
89-
unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
90-
unsigned startPos =
91-
cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
92-
.getPosition();
93-
unsigned endPos =
94-
cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
95-
.getPosition();
96-
int64_t linearizedStaticDim = 1;
97-
for (auto d :
98-
llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
99-
if (d.index() + startPos == static_cast<unsigned>(dimIndex))
100-
continue;
101-
assert(!ShapedType::isDynamic(d.value()) &&
102-
"single dimension cannot be expanded into multiple dynamic "
103-
"dimensions");
104-
linearizedStaticDim *= d.value();
61+
struct ReifyCollapseShapeOp
62+
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
63+
ReifyCollapseShapeOp, CollapseShapeOp> {
64+
LogicalResult
65+
reifyResultShapes(Operation *op, OpBuilder &b,
66+
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
67+
auto loc = op->getLoc();
68+
auto collapseShape = cast<CollapseShapeOp>(op);
69+
reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
70+
b, loc, collapseShape.getSrc(),
71+
collapseShape.getResultType().getShape(),
72+
collapseShape.getReassociationMaps()));
73+
return success();
10574
}
106-
OpFoldResult sourceDim =
107-
builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
108-
109-
// Dynamic dimension: return Value.
110-
return affine::makeComposedAffineApply(
111-
builder, loc,
112-
AffineMap::get(
113-
0, 1,
114-
builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
115-
sourceDim)
116-
->getResult(0);
117-
}
118-
119-
/// Given the `src` of an expanding reshape op, the reassociation maps and the
120-
/// result type, compute the shape of the result of the reshape.
121-
static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
122-
OpBuilder &builder, Location loc, Value src,
123-
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
124-
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
125-
getExpandedDimToCollapsedDimMap(reassociation);
126-
return llvm::to_vector<4>(llvm::map_range(
127-
llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
128-
return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
129-
dstStaticShape, reassociation,
130-
expandedDimToCollapsedDim);
131-
}));
132-
}
133-
134-
static SmallVector<OpFoldResult, 4>
135-
getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
136-
ArrayRef<int64_t> dstStaticShape,
137-
ArrayRef<AffineMap> reassocation) {
138-
return dstStaticShape.size() >
139-
static_cast<size_t>(
140-
llvm::cast<ShapedType>(src.getType()).getRank())
141-
? getExpandedOutputShapeFromInputShape(
142-
builder, loc, src, dstStaticShape, reassocation)
143-
: getCollapsedOutputShapeFromInputShape(
144-
builder, loc, src, dstStaticShape, reassocation);
145-
}
75+
};
14676

147-
template <typename OpTy>
148-
struct ReifyExpandOrCollapseShapeOp
149-
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
150-
ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
77+
struct ReifyExpandShapeOp
78+
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
79+
ExpandShapeOp> {
15180
LogicalResult
15281
reifyResultShapes(Operation *op, OpBuilder &b,
15382
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
15483
auto loc = op->getLoc();
155-
auto reshapeOp = cast<OpTy>(op);
156-
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
157-
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
158-
reshapeOp.getReassociationMaps()));
84+
auto expandShape = cast<ExpandShapeOp>(op);
85+
SmallVector<OpFoldResult> outputShape = getMixedValues(
86+
expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
87+
reifiedReturnShapes.push_back(outputShape);
15988
return success();
16089
}
16190
};
@@ -202,10 +131,8 @@ struct ReifyPadOp
202131
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
203132
DialectRegistry &registry) {
204133
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
205-
ExpandShapeOp::attachInterface<
206-
ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
207-
CollapseShapeOp::attachInterface<
208-
ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
134+
ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
135+
CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
209136
PadOp::attachInterface<ReifyPadOp>(*ctx);
210137
});
211138
}

mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
210210
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
211211
return %1, %2, %3 : index, index, index
212212
}
213-
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
214213
// CHECK: func @dim_reshape_expansion
215214
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
216-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
215+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
217216
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
218217
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
219-
// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
220-
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
221-
// CHECK: return %[[C3]], %[[C4]], %[[D1]]
218+
// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
222219

223220
// -----
224221

mlir/test/Dialect/Tensor/fold-empty-op.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
1010
}
1111
}
1212

13-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
1413
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
1514

1615
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,8 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
1918
return %1 : tensor<2x3x5x4x?x7xf32>
2019
}
2120
// CHECK-LABEL: func @empty_reshape_expansion
22-
// CHECK-SAME: %[[ARG0:.+]]: index
23-
// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
24-
// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
25-
// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
26-
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
21+
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index
22+
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
2723
// CHECK-NEXT: return %[[INIT]]
2824

2925
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

0 commit comments

Comments (0)