1616using namespace mlir ;
1717using namespace mlir ::tensor;
1818
19- // / Compute a map that for a given dimension of the expanded type gives the
20- // / dimension in the collapsed type it maps to. Essentially its the inverse of
21- // / the `reassocation` maps.
22- static llvm::DenseMap<int64_t , int64_t >
23- getExpandedDimToCollapsedDimMap (ArrayRef<AffineMap> reassociation) {
24- llvm::DenseMap<int64_t , int64_t > expandedDimToCollapsedDim;
25- for (const auto &map : enumerate(reassociation)) {
26- unsigned startPos =
27- cast<AffineDimExpr>(map.value ().getResults ().front ()).getPosition ();
28- unsigned endPos =
29- cast<AffineDimExpr>(map.value ().getResults ().back ()).getPosition ();
30- for (auto dim : llvm::seq_inclusive (startPos, endPos)) {
31- expandedDimToCollapsedDim[dim] = map.index ();
32- }
33- }
34- return expandedDimToCollapsedDim;
35- }
36-
3719// / For reshape op compute the shape at dimension `dimIndex` of the output in
3820// / terms of shape of the `src`, when the reshape op is a collapsing
3921// / operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,84 +58,15 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
7658 }));
7759}
7860
79- // / For an expanding reshape op, compute the value for a dimension of the output
80- // / from the shape of the input.
81- static OpFoldResult getExpandedOutputDimFromInputShape (
82- OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
83- ArrayRef<int64_t > dstStaticShape, ArrayRef<AffineMap> reassociation,
84- llvm::DenseMap<int64_t , int64_t > &expandedDimToCollapsedDim) {
85- if (!ShapedType::isDynamic (dstStaticShape[dimIndex])) {
86- // Static dimension: return Attribute.
87- return builder.getIndexAttr (dstStaticShape[dimIndex]);
88- }
89- unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
90- unsigned startPos =
91- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults ().front ())
92- .getPosition ();
93- unsigned endPos =
94- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults ().back ())
95- .getPosition ();
96- int64_t linearizedStaticDim = 1 ;
97- for (auto d :
98- llvm::enumerate (dstStaticShape.slice (startPos, endPos - startPos + 1 ))) {
99- if (d.index () + startPos == static_cast <unsigned >(dimIndex))
100- continue ;
101- assert (!ShapedType::isDynamic (d.value ()) &&
102- " single dimension cannot be expanded into multiple dynamic "
103- " dimensions" );
104- linearizedStaticDim *= d.value ();
105- }
106- OpFoldResult sourceDim =
107- builder.create <tensor::DimOp>(loc, src, sourceDimPos).getResult ();
108-
109- // Dynamic dimension: return Value.
110- return affine::makeComposedAffineApply (
111- builder, loc,
112- AffineMap::get (
113- 0 , 1 ,
114- builder.getAffineSymbolExpr (0 ).floorDiv (linearizedStaticDim)),
115- sourceDim)
116- ->getResult (0 );
117- }
118-
119- // / Given the `src` of an expanding reshape op, the reassociation maps and the
120- // / result type, compute the shape of the result of the reshape.
121- static SmallVector<OpFoldResult, 4 > getExpandedOutputShapeFromInputShape (
122- OpBuilder &builder, Location loc, Value src,
123- ArrayRef<int64_t > dstStaticShape, ArrayRef<AffineMap> reassociation) {
124- llvm::DenseMap<int64_t , int64_t > expandedDimToCollapsedDim =
125- getExpandedDimToCollapsedDimMap (reassociation);
126- return llvm::to_vector<4 >(llvm::map_range (
127- llvm::seq<int64_t >(0 , dstStaticShape.size ()), [&](int64_t dim) {
128- return getExpandedOutputDimFromInputShape (builder, loc, dim, src,
129- dstStaticShape, reassociation,
130- expandedDimToCollapsedDim);
131- }));
132- }
133-
134- static SmallVector<OpFoldResult, 4 >
135- getReshapeOutputShapeFromInputShape (OpBuilder &builder, Location loc, Value src,
136- ArrayRef<int64_t > dstStaticShape,
137- ArrayRef<AffineMap> reassocation) {
138- return dstStaticShape.size () >
139- static_cast <size_t >(
140- llvm::cast<ShapedType>(src.getType ()).getRank ())
141- ? getExpandedOutputShapeFromInputShape (
142- builder, loc, src, dstStaticShape, reassocation)
143- : getCollapsedOutputShapeFromInputShape (
144- builder, loc, src, dstStaticShape, reassocation);
145- }
146-
147- template <typename OpTy>
148- struct ReifyExpandOrCollapseShapeOp
61+ struct ReifyCollapseShapeOp
14962 : public ReifyRankedShapedTypeOpInterface::ExternalModel<
150- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy > {
63+ ReifyCollapseShapeOp, CollapseShapeOp > {
15164 LogicalResult
15265 reifyResultShapes (Operation *op, OpBuilder &b,
15366 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
15467 auto loc = op->getLoc ();
155- auto reshapeOp = cast<OpTy >(op);
156- reifiedReturnShapes.push_back (getReshapeOutputShapeFromInputShape (
68+ auto reshapeOp = cast<tensor::CollapseShapeOp >(op);
69+ reifiedReturnShapes.push_back (getCollapsedOutputShapeFromInputShape (
15770 b, loc, reshapeOp.getSrc (), reshapeOp.getResultType ().getShape (),
15871 reshapeOp.getReassociationMaps ()));
15972 return success ();
@@ -162,6 +75,20 @@ struct ReifyExpandOrCollapseShapeOp
16275
16376namespace {
16477
78+ struct ReifyExpandShapeOp
79+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
80+ ExpandShapeOp> {
81+ LogicalResult
82+ reifyResultShapes (Operation *op, OpBuilder &b,
83+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
84+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
85+ SmallVector<OpFoldResult> resultShapes =
86+ expandShapeOp.getMixedOutputShape ();
87+ reifyResultShapes.emplace_back (std::move (resultShapes));
88+ return success ();
89+ }
90+ };
91+
16592struct ReifyPadOp
16693 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
16794 PadOp> {
@@ -202,10 +129,8 @@ struct ReifyPadOp
202129void mlir::tensor::registerInferTypeOpInterfaceExternalModels (
203130 DialectRegistry ®istry) {
204131 registry.addExtension (+[](MLIRContext *ctx, TensorDialect *dialect) {
205- ExpandShapeOp::attachInterface<
206- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
207- CollapseShapeOp::attachInterface<
208- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
132+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
133+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
209134 PadOp::attachInterface<ReifyPadOp>(*ctx);
210135 });
211136}
0 commit comments