@@ -98,6 +98,27 @@ struct RankOpInterface
9898 }
9999};
100100
101+ struct CollapseShapeOpInterface
102+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
103+ memref::CollapseShapeOp> {
104+ void populateBoundsForShapedValueDim (Operation *op, Value value, int64_t dim,
105+ ValueBoundsConstraintSet &cstr) const {
106+ auto collapseOp = cast<memref::CollapseShapeOp>(op);
107+ assert (value == collapseOp.getResult () && " invalid value" );
108+
109+ // Multiply the expressions for the dimensions in the reassociation group.
110+ const ReassociationIndices &reassocIndices =
111+ collapseOp.getReassociationIndices ()[dim];
112+ AffineExpr productExpr =
113+ cstr.getExpr (collapseOp.getSrc (), reassocIndices[0 ]);
114+ for (size_t i = 1 ; i < reassocIndices.size (); ++i) {
115+ productExpr =
116+ productExpr * cstr.getExpr (collapseOp.getSrc (), reassocIndices[i]);
117+ }
118+ cstr.bound (value)[dim] == productExpr;
119+ }
120+ };
121+
101122struct SubViewOpInterface
102123 : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
103124 SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
134155 memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
135156 memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
136157 memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
158+ memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
159+ *ctx);
137160 memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
138161 *ctx);
139162 memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
0 commit comments