@@ -1136,23 +1136,19 @@ struct FoldReshapeWithProducerPadOpByExpansion
11361136 SmallVector<OpFoldResult> low = padOp.getMixedLowPad ();
11371137 SmallVector<OpFoldResult> high = padOp.getMixedHighPad ();
11381138
1139- for (auto [reInd, l, h] : llvm::zip_equal (reassociations, low, high)) {
1140- if (reInd.size () > 1 &&
1141- (!isConstantIntValue (l, 0 ) || !isConstantIntValue (h, 0 )))
1139+ SmallVector<OpFoldResult> newLow, newHigh;
1140+ for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
1141+ if (reInd.size () > 1 && (!isConstantIntValue (low[idx], 0 ) ||
1142+ !isConstantIntValue (high[idx], 0 )))
11421143 return rewriter.notifyMatchFailure (
11431144 expandOp, " fusion blocked by non-zero padding" );
1144- }
11451145
1146- SmallVector<OpFoldResult> newLow, newHigh;
1147- for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
11481146 newLow.append (reInd.size (), low[idx]);
11491147 newHigh.append (reInd.size (), high[idx]);
11501148 }
11511149
11521150 Location loc = expandOp.getLoc ();
1153- ArrayRef<int64_t > finalShape = expandOp.getResultType ().getShape ();
11541151 SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape ();
1155-
11561152 for (auto [inDimIdx, reInd] : llvm::enumerate (reassociations)) {
11571153 OpFoldResult l = low[inDimIdx];
11581154 OpFoldResult h = high[inDimIdx];
@@ -1165,15 +1161,6 @@ struct FoldReshapeWithProducerPadOpByExpansion
11651161 }
11661162 }
11671163
1168- for (auto [outDimIdx, dimSize] : llvm::enumerate (finalShape)) {
1169- if (dimSize == ShapedType::kDynamic &&
1170- !isa<Value>(expandedShape[outDimIdx]) &&
1171- !isa<Attribute>(expandedShape[outDimIdx])) {
1172- expandedShape[outDimIdx] =
1173- tensor::getMixedSize (rewriter, loc, expandOp.getSrc (), outDimIdx);
1174- }
1175- }
1176-
11771164 SmallVector<int64_t > staticExpandedShape;
11781165 std::tie (staticExpandedShape, std::ignore) =
11791166 decomposeMixedValues (expandedShape);
0 commit comments