|
16 | 16 | using namespace mlir; |
17 | 17 | using namespace mlir::tensor; |
18 | 18 |
|
19 | | -/// Compute a map that for a given dimension of the expanded type gives the |
20 | | -/// dimension in the collapsed type it maps to. Essentially its the inverse of |
21 | | -/// the `reassocation` maps. |
22 | | -static llvm::DenseMap<int64_t, int64_t> |
23 | | -getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) { |
24 | | - llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim; |
25 | | - for (const auto &map : enumerate(reassociation)) { |
26 | | - unsigned startPos = |
27 | | - cast<AffineDimExpr>(map.value().getResults().front()).getPosition(); |
28 | | - unsigned endPos = |
29 | | - cast<AffineDimExpr>(map.value().getResults().back()).getPosition(); |
30 | | - for (auto dim : llvm::seq_inclusive(startPos, endPos)) { |
31 | | - expandedDimToCollapsedDim[dim] = map.index(); |
32 | | - } |
33 | | - } |
34 | | - return expandedDimToCollapsedDim; |
35 | | -} |
36 | | - |
37 | 19 | /// For reshape op compute the shape at dimension `dimIndex` of the output in |
38 | 20 | /// terms of shape of the `src`, when the reshape op is a collapsing |
39 | 21 | /// operation. It is the product of the shape of the collapsed dimensions of the |
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape( |
76 | 58 | })); |
77 | 59 | } |
78 | 60 |
|
79 | | -/// For an expanding reshape op, compute the value for a dimension of the output |
80 | | -/// from the shape of the input. |
81 | | -static OpFoldResult getExpandedOutputDimFromInputShape( |
82 | | - OpBuilder &builder, Location loc, int64_t dimIndex, Value src, |
83 | | - ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation, |
84 | | - llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) { |
85 | | - if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { |
86 | | - // Static dimension: return Attribute. |
87 | | - return builder.getIndexAttr(dstStaticShape[dimIndex]); |
88 | | - } |
89 | | - unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; |
90 | | - unsigned startPos = |
91 | | - cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front()) |
92 | | - .getPosition(); |
93 | | - unsigned endPos = |
94 | | - cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back()) |
95 | | - .getPosition(); |
96 | | - int64_t linearizedStaticDim = 1; |
97 | | - for (auto d : |
98 | | - llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { |
99 | | - if (d.index() + startPos == static_cast<unsigned>(dimIndex)) |
100 | | - continue; |
101 | | - assert(!ShapedType::isDynamic(d.value()) && |
102 | | - "single dimension cannot be expanded into multiple dynamic " |
103 | | - "dimensions"); |
104 | | - linearizedStaticDim *= d.value(); |
| 61 | +struct ReifyCollapseShapeOp |
| 62 | + : public ReifyRankedShapedTypeOpInterface::ExternalModel< |
| 63 | + ReifyCollapseShapeOp, CollapseShapeOp> { |
| 64 | + LogicalResult |
| 65 | + reifyResultShapes(Operation *op, OpBuilder &b, |
| 66 | + ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { |
| 67 | + auto loc = op->getLoc(); |
| 68 | + auto collapseShape = cast<CollapseShapeOp>(op); |
| 69 | + reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape( |
| 70 | + b, loc, collapseShape.getSrc(), |
| 71 | + collapseShape.getResultType().getShape(), |
| 72 | + collapseShape.getReassociationMaps())); |
| 73 | + return success(); |
105 | 74 | } |
106 | | - OpFoldResult sourceDim = |
107 | | - builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult(); |
108 | | - |
109 | | - // Dynamic dimension: return Value. |
110 | | - return affine::makeComposedAffineApply( |
111 | | - builder, loc, |
112 | | - AffineMap::get( |
113 | | - 0, 1, |
114 | | - builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)), |
115 | | - sourceDim) |
116 | | - ->getResult(0); |
117 | | -} |
118 | | - |
119 | | -/// Given the `src` of an expanding reshape op, the reassociation maps and the |
120 | | -/// result type, compute the shape of the result of the reshape. |
121 | | -static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape( |
122 | | - OpBuilder &builder, Location loc, Value src, |
123 | | - ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) { |
124 | | - llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim = |
125 | | - getExpandedDimToCollapsedDimMap(reassociation); |
126 | | - return llvm::to_vector<4>(llvm::map_range( |
127 | | - llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) { |
128 | | - return getExpandedOutputDimFromInputShape(builder, loc, dim, src, |
129 | | - dstStaticShape, reassociation, |
130 | | - expandedDimToCollapsedDim); |
131 | | - })); |
132 | | -} |
133 | | - |
134 | | -static SmallVector<OpFoldResult, 4> |
135 | | -getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, |
136 | | - ArrayRef<int64_t> dstStaticShape, |
137 | | - ArrayRef<AffineMap> reassocation) { |
138 | | - return dstStaticShape.size() > |
139 | | - static_cast<size_t>( |
140 | | - llvm::cast<ShapedType>(src.getType()).getRank()) |
141 | | - ? getExpandedOutputShapeFromInputShape( |
142 | | - builder, loc, src, dstStaticShape, reassocation) |
143 | | - : getCollapsedOutputShapeFromInputShape( |
144 | | - builder, loc, src, dstStaticShape, reassocation); |
145 | | -} |
| 75 | +}; |
146 | 76 |
|
147 | | -template <typename OpTy> |
148 | | -struct ReifyExpandOrCollapseShapeOp |
149 | | - : public ReifyRankedShapedTypeOpInterface::ExternalModel< |
150 | | - ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> { |
| 77 | +struct ReifyExpandShapeOp |
| 78 | + : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, |
| 79 | + ExpandShapeOp> { |
151 | 80 | LogicalResult |
152 | 81 | reifyResultShapes(Operation *op, OpBuilder &b, |
153 | 82 | ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { |
154 | 83 | auto loc = op->getLoc(); |
155 | | - auto reshapeOp = cast<OpTy>(op); |
156 | | - reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape( |
157 | | - b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(), |
158 | | - reshapeOp.getReassociationMaps())); |
| 84 | + auto expandShape = cast<ExpandShapeOp>(op); |
| 85 | + SmallVector<OpFoldResult> outputShape = getMixedValues( |
| 86 | + expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b); |
| 87 | + reifiedReturnShapes.push_back(outputShape); |
159 | 88 | return success(); |
160 | 89 | } |
161 | 90 | }; |
@@ -202,10 +131,8 @@ struct ReifyPadOp |
202 | 131 | void mlir::tensor::registerInferTypeOpInterfaceExternalModels( |
203 | 132 | DialectRegistry ®istry) { |
204 | 133 | registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { |
205 | | - ExpandShapeOp::attachInterface< |
206 | | - ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx); |
207 | | - CollapseShapeOp::attachInterface< |
208 | | - ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx); |
| 134 | + ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx); |
| 135 | + CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx); |
209 | 136 | PadOp::attachInterface<ReifyPadOp>(*ctx); |
210 | 137 | }); |
211 | 138 | } |
0 commit comments