Skip to content

Commit 737d4a4

Browse files
committed
fix upon review
1 parent 57ec657 commit 737d4a4

File tree

1 file changed

+22
-53
lines changed

1 file changed

+22
-53
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 22 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,20 +1102,6 @@ class FoldPadWithProducerReshapeOpByExpansion
11021102
ControlFusionFn controlFoldingReshapes;
11031103
};
11041104

1105-
bool isZero(OpFoldResult value) {
1106-
if (auto attr = dyn_cast<Attribute>(value)) {
1107-
if (auto intAttr = dyn_cast<IntegerAttr>(attr))
1108-
return intAttr.getInt() == 0;
1109-
}
1110-
if (auto val = dyn_cast<Value>(value)) {
1111-
if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
1112-
if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
1113-
return attr.getInt() == 0;
1114-
}
1115-
}
1116-
return false;
1117-
}
1118-
11191105
/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
11201106
/// by bubbling the expand_shape before the pad.
11211107
struct FoldReshapeWithProducerPadOpByExpansion
@@ -1152,19 +1138,17 @@ struct FoldReshapeWithProducerPadOpByExpansion
11521138
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
11531139
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
11541140

1155-
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1156-
OpFoldResult l = low[idx];
1157-
OpFoldResult h = high[idx];
1158-
if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
1159-
return failure();
1141+
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1142+
if (reInd.size() > 1 &&
1143+
(!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
1144+
return rewriter.notifyMatchFailure(
1145+
expandOp, "fusion blocked by non-zero padding");
11601146
}
11611147

11621148
SmallVector<OpFoldResult> newLow, newHigh;
11631149
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1164-
for (size_t i = 0; i < reInd.size(); ++i) {
1165-
newLow.push_back(low[idx]);
1166-
newHigh.push_back(high[idx]);
1167-
}
1150+
newLow.append(reInd.size(), low[idx]);
1151+
newHigh.append(reInd.size(), high[idx]);
11681152
}
11691153

11701154
Location loc = expandOp.getLoc();
@@ -1184,7 +1168,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
11841168
OpFoldResult l = low[inDimIdx];
11851169
OpFoldResult h = high[inDimIdx];
11861170

1187-
if (!isZero(l) || !isZero(h)) {
1171+
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
11881172
auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
11891173
int64_t originalSize = srcType.getDimSize(inDimIdx);
11901174

@@ -1196,47 +1180,33 @@ struct FoldReshapeWithProducerPadOpByExpansion
11961180
} else {
11971181
originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
11981182
}
1199-
1200-
for (auto outDimIdx : reInd) {
1201-
expandedShape[outDimIdx] = originalSizeOFR;
1202-
}
1183+
assert(reInd.size() == 1 && "expected single dimension");
1184+
expandedShape[reInd[0]] = originalSizeOFR;
12031185
}
12041186
}
12051187

12061188
for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
12071189
if (dimSize == ShapedType::kDynamic &&
12081190
!isa<Value>(expandedShape[outDimIdx]) &&
12091191
!isa<Attribute>(expandedShape[outDimIdx])) {
1210-
Value actualSize =
1211-
rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
1212-
expandedShape[outDimIdx] = actualSize;
1192+
expandedShape[outDimIdx] =
1193+
tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
12131194
}
12141195
}
12151196

12161197
SmallVector<int64_t> staticExpandedShape;
1217-
for (OpFoldResult dim : expandedShape) {
1218-
if (auto attr = dyn_cast<Attribute>(dim)) {
1219-
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
1220-
staticExpandedShape.push_back(intAttr.getInt());
1221-
} else {
1222-
staticExpandedShape.push_back(ShapedType::kDynamic);
1223-
}
1224-
} else {
1225-
staticExpandedShape.push_back(ShapedType::kDynamic);
1226-
}
1227-
}
1198+
std::tie(staticExpandedShape, std::ignore) =
1199+
decomposeMixedValues(expandedShape);
12281200

12291201
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
12301202
loc,
12311203
RankedTensorType::get(staticExpandedShape,
12321204
padOp.getSource().getType().getElementType()),
1233-
padOp.getSource(), reassociations);
1205+
padOp.getSource(), reassociations, expandedShape);
12341206

1235-
auto newPadOp = rewriter.create<tensor::PadOp>(
1236-
loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
1207+
rewriter.replaceOpWithNewOp<tensor::PadOp>(
1208+
expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
12371209
padOp.getConstantPaddingValue(), padOp.getNofold());
1238-
1239-
rewriter.replaceOp(expandOp, newPadOp.getResult());
12401210
return success();
12411211
}
12421212

@@ -1284,7 +1254,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12841254
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
12851255
if (reInd.size() > 1) {
12861256
for (auto dimIdx : reInd) {
1287-
if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
1257+
if (!isConstantIntValue(low[dimIdx], 0) ||
1258+
!isConstantIntValue(high[dimIdx], 0)) {
12881259
return failure();
12891260
}
12901261
}
@@ -1316,7 +1287,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing
13161287
OpFoldResult l = low[reInd[0]];
13171288
OpFoldResult h = high[reInd[0]];
13181289

1319-
if (!isZero(l) || !isZero(h)) {
1290+
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
13201291
auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
13211292
int64_t originalSize = srcType.getDimSize(reInd[0]);
13221293

@@ -1350,12 +1321,10 @@ struct FoldReshapeWithProducerPadOpByCollapsing
13501321
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
13511322
loc, newCollapseType, padOp.getSource(), reassociations);
13521323

1353-
auto newPadOp = rewriter.create<tensor::PadOp>(
1354-
loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
1324+
rewriter.replaceOpWithNewOp<tensor::PadOp>(
1325+
collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh,
13551326
padOp.getConstantPaddingValue(), padOp.getNofold());
13561327

1357-
rewriter.replaceOp(collapseOp, newPadOp.getResult());
1358-
13591328
return success();
13601329
}
13611330

0 commit comments

Comments
 (0)