@@ -1217,20 +1217,15 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12171217 collapseOp.getReassociationIndices ();
12181218 SmallVector<OpFoldResult> low = padOp.getMixedLowPad ();
12191219 SmallVector<OpFoldResult> high = padOp.getMixedHighPad ();
1220-
1220+ SmallVector<OpFoldResult> newLow, newHigh;
12211221 for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
1222- if (reInd.size () > 1 ) {
1223- for (auto dimIdx : reInd) {
1224- if (!isConstantIntValue (low[dimIdx], 0 ) ||
1225- !isConstantIntValue (high[dimIdx], 0 )) {
1226- return failure ();
1227- }
1228- }
1222+ if (reInd.size () > 1 && llvm::any_of (reInd, [&](int64_t dimIdx) {
1223+ return !isConstantIntValue (low[dimIdx], 0 ) ||
1224+ !isConstantIntValue (high[dimIdx], 0 );
1225+ })) {
1226+ return failure ();
12291227 }
1230- }
12311228
1232- SmallVector<OpFoldResult> newLow, newHigh;
1233- for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
12341229 newLow.push_back (low[reInd[0 ]]);
12351230 newHigh.push_back (high[reInd[0 ]]);
12361231 }
@@ -1244,16 +1239,12 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12441239 OpFoldResult l = low[reInd[0 ]];
12451240 OpFoldResult h = high[reInd[0 ]];
12461241 if (!isConstantIntValue (l, 0 ) || !isConstantIntValue (h, 0 )) {
1247- auto mixedSize =
1248- tensor::getMixedSize (rewriter, loc, padOp.getSource (), reInd[0 ]);
1249- auto dimSize = getConstantIntValue (mixedSize);
1250- assert (dimSize.has_value () && " Expected static dimension" );
1251- collapsedShape[inDimIdx] = *dimSize;
1242+ collapsedShape[inDimIdx] = padOp.getSourceType ().getShape ()[reInd[0 ]];
12521243 }
12531244 }
12541245
12551246 auto newCollapseType = RankedTensorType::get (
1256- collapsedShape, padOp.getSource (). getType ().getElementType ());
1247+ collapsedShape, padOp.getSourceType ().getElementType ());
12571248 auto newCollapseOp = rewriter.create <tensor::CollapseShapeOp>(
12581249 loc, newCollapseType, padOp.getSource (), reassociations);
12591250
0 commit comments