@@ -1102,20 +1102,6 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };

-bool isZero(OpFoldResult value) {
-  if (auto attr = dyn_cast<Attribute>(value)) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
-      return intAttr.getInt() == 0;
-  }
-  if (auto val = dyn_cast<Value>(value)) {
-    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
-      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
-        return attr.getInt() == 0;
-    }
-  }
-  return false;
-}
-
 /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
 /// by bubbling the expand_shape before the pad.
 struct FoldReshapeWithProducerPadOpByExpansion
@@ -1152,19 +1138,17 @@ struct FoldReshapeWithProducerPadOpByExpansion
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = low[idx];
-      OpFoldResult h = high[idx];
-      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
-        return failure();
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() > 1 &&
+          (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
+        return rewriter.notifyMatchFailure(
+            expandOp, "fusion blocked by non-zero padding");
     }

     SmallVector<OpFoldResult> newLow, newHigh;
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(low[idx]);
-        newHigh.push_back(high[idx]);
-      }
+      newLow.append(reInd.size(), low[idx]);
+      newHigh.append(reInd.size(), high[idx]);
     }

     Location loc = expandOp.getLoc();
@@ -1184,7 +1168,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
       OpFoldResult l = low[inDimIdx];
       OpFoldResult h = high[inDimIdx];

-      if (!isZero(l) || !isZero(h)) {
+      if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(inDimIdx);

@@ -1196,47 +1180,33 @@ struct FoldReshapeWithProducerPadOpByExpansion
         } else {
           originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
         }
-
-        for (auto outDimIdx : reInd) {
-          expandedShape[outDimIdx] = originalSizeOFR;
-        }
+        assert(reInd.size() == 1 && "expected single dimension");
+        expandedShape[reInd[0]] = originalSizeOFR;
       }
     }

     for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
       if (dimSize == ShapedType::kDynamic &&
           !isa<Value>(expandedShape[outDimIdx]) &&
           !isa<Attribute>(expandedShape[outDimIdx])) {
-        Value actualSize =
-            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
-        expandedShape[outDimIdx] = actualSize;
+        expandedShape[outDimIdx] =
+            tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
       }
     }

     SmallVector<int64_t> staticExpandedShape;
-    for (OpFoldResult dim : expandedShape) {
-      if (auto attr = dyn_cast<Attribute>(dim)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-          staticExpandedShape.push_back(intAttr.getInt());
-        } else {
-          staticExpandedShape.push_back(ShapedType::kDynamic);
-        }
-      } else {
-        staticExpandedShape.push_back(ShapedType::kDynamic);
-      }
-    }
+    std::tie(staticExpandedShape, std::ignore) =
+        decomposeMixedValues(expandedShape);

     auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
         loc,
         RankedTensorType::get(staticExpandedShape,
                               padOp.getSource().getType().getElementType()),
-        padOp.getSource(), reassociations);
+        padOp.getSource(), reassociations, expandedShape);

-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
         padOp.getConstantPaddingValue(), padOp.getNofold());
-
-    rewriter.replaceOp(expandOp, newPadOp.getResult());
     return success();
   }

@@ -1284,7 +1254,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       if (reInd.size() > 1) {
         for (auto dimIdx : reInd) {
-          if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
+          if (!isConstantIntValue(low[dimIdx], 0) ||
+              !isConstantIntValue(high[dimIdx], 0)) {
             return failure();
           }
         }
@@ -1316,7 +1287,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing
       OpFoldResult l = low[reInd[0]];
       OpFoldResult h = high[reInd[0]];

-      if (!isZero(l) || !isZero(h)) {
+      if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(reInd[0]);

@@ -1350,12 +1321,10 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
         loc, newCollapseType, padOp.getSource(), reassociations);

-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh,
         padOp.getConstantPaddingValue(), padOp.getNofold());

-    rewriter.replaceOp(collapseOp, newPadOp.getResult());
-
     return success();
   }

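For context, here is a minimal hand-written MLIR sketch of the rewrite that `FoldReshapeWithProducerPadOpByExpansion` performs (the `%src` and `%cst` values are illustrative assumptions, not taken from this patch or its tests). The `tensor.expand_shape` is bubbled above the `tensor.pad`, which is legal here because every reassociation group with more than one result dimension carries zero padding:

```mlir
// Before: the pad feeds the reshape, blocking reshape propagation.
// %src : tensor<8x16xf32>, %cst : f32 (assumed)
%pad = tensor.pad %src low[0, 1] high[0, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x16xf32> to tensor<8x18xf32>
%expand = tensor.expand_shape %pad [[0, 1], [2]] output_shape [2, 4, 18]
    : tensor<8x18xf32> into tensor<2x4x18xf32>

// After: expand first, then pad the single result dimension that
// corresponds to the padded source dimension.
%expand2 = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 16]
    : tensor<8x16xf32> into tensor<2x4x16xf32>
%pad2 = tensor.pad %expand2 low[0, 0, 1] high[0, 0, 1] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<2x4x16xf32> to tensor<2x4x18xf32>
```

The collapsing pattern is the mirror image: non-zero padding is only tolerated on dimensions whose reassociation group collapses to a single result dimension.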