@@ -1239,32 +1239,21 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12391239 auto resultType = collapseOp.getResultType ();
12401240
12411241 ArrayRef<int64_t > finalShape = collapseOp.getResultType ().getShape ();
1242-
1243- SmallVector<OpFoldResult> collapsedShape;
1244- for (int64_t dimSize : finalShape) {
1245- if (dimSize == ShapedType::kDynamic ) {
1246- collapsedShape.push_back (OpFoldResult{});
1247- } else {
1248- collapsedShape.push_back (rewriter.getI64IntegerAttr (dimSize));
1249- }
1250- }
1251-
1242+ SmallVector<int64_t > collapsedShape (finalShape.begin (), finalShape.end ());
12521243 for (auto [inDimIdx, reInd] : llvm::enumerate (reassociations)) {
12531244 OpFoldResult l = low[reInd[0 ]];
12541245 OpFoldResult h = high[reInd[0 ]];
1255-
12561246 if (!isConstantIntValue (l, 0 ) || !isConstantIntValue (h, 0 )) {
1257- collapsedShape[inDimIdx] =
1247+ auto mixedSize =
12581248 tensor::getMixedSize (rewriter, loc, padOp.getSource (), reInd[0 ]);
1249+ auto dimSize = getConstantIntValue (mixedSize);
1250+ assert (dimSize.has_value () && " Expected static dimension" );
1251+ collapsedShape[inDimIdx] = *dimSize;
12591252 }
12601253 }
12611254
1262- SmallVector<int64_t > staticCollapsedShape;
1263- std::tie (staticCollapsedShape, std::ignore) =
1264- decomposeMixedValues (collapsedShape);
1265-
12661255 auto newCollapseType = RankedTensorType::get (
1267- staticCollapsedShape , padOp.getSource ().getType ().getElementType ());
1256+ collapsedShape , padOp.getSource ().getType ().getElementType ());
12681257 auto newCollapseOp = rewriter.create <tensor::CollapseShapeOp>(
12691258 loc, newCollapseType, padOp.getSource (), reassociations);
12701259
0 commit comments