@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
144144 builder, loc, src, dstStaticShape, reassocation);
145145}
146146
147- template <typename OpTy>
148- struct ReifyExpandOrCollapseShapeOp
147+ struct ReifyCollapseShapeOp
149148 : public ReifyRankedShapedTypeOpInterface::ExternalModel<
150- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy > {
149+ ReifyCollapseShapeOp, CollapseShapeOp > {
151150 LogicalResult
152151 reifyResultShapes (Operation *op, OpBuilder &b,
153152 ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
154153 auto loc = op->getLoc ();
155- auto reshapeOp = cast<OpTy >(op);
154+ auto reshapeOp = cast<tensor::CollapseShapeOp >(op);
156155 reifiedReturnShapes.push_back (getReshapeOutputShapeFromInputShape (
157156 b, loc, reshapeOp.getSrc (), reshapeOp.getResultType ().getShape (),
158157 reshapeOp.getReassociationMaps ()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
162161
163162namespace {
164163
164+ struct ReifyExpandShapeOp
165+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
166+ ExpandShapeOp> {
167+ LogicalResult
168+ reifyResultShapes (Operation *op, OpBuilder &b,
169+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
170+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
171+ SmallVector<OpFoldResult> resultShapes =
172+ expandShapeOp.getMixedOutputShape ();
173+ reifyResultShapes.emplace_back (std::move (resultShapes));
174+ return success ();
175+ }
176+ };
177+
165178struct ReifyPadOp
166179 : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
167180 PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
202215void mlir::tensor::registerInferTypeOpInterfaceExternalModels (
203216 DialectRegistry ®istry) {
204217 registry.addExtension (+[](MLIRContext *ctx, TensorDialect *dialect) {
205- ExpandShapeOp::attachInterface<
206- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
207- CollapseShapeOp::attachInterface<
208- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
218+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
219+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
209220 PadOp::attachInterface<ReifyPadOp>(*ctx);
210221 });
211222}
0 commit comments