Skip to content

Commit 0faf084

Browse files
committed
fix upon review
1 parent 9c38ad5 commit 0faf084

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,20 +1217,15 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12171217
collapseOp.getReassociationIndices();
12181218
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
12191219
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
1220-
1220+
SmallVector<OpFoldResult> newLow, newHigh;
12211221
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1222-
if (reInd.size() > 1) {
1223-
for (auto dimIdx : reInd) {
1224-
if (!isConstantIntValue(low[dimIdx], 0) ||
1225-
!isConstantIntValue(high[dimIdx], 0)) {
1226-
return failure();
1227-
}
1228-
}
1222+
if (reInd.size() > 1 && llvm::any_of(reInd, [&](int64_t dimIdx) {
1223+
return !isConstantIntValue(low[dimIdx], 0) ||
1224+
!isConstantIntValue(high[dimIdx], 0);
1225+
})) {
1226+
return failure();
12291227
}
1230-
}
12311228

1232-
SmallVector<OpFoldResult> newLow, newHigh;
1233-
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
12341229
newLow.push_back(low[reInd[0]]);
12351230
newHigh.push_back(high[reInd[0]]);
12361231
}
@@ -1244,16 +1239,12 @@ struct FoldReshapeWithProducerPadOpByCollapsing
12441239
OpFoldResult l = low[reInd[0]];
12451240
OpFoldResult h = high[reInd[0]];
12461241
if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
1247-
auto mixedSize =
1248-
tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
1249-
auto dimSize = getConstantIntValue(mixedSize);
1250-
assert(dimSize.has_value() && "Expected static dimension");
1251-
collapsedShape[inDimIdx] = *dimSize;
1242+
collapsedShape[inDimIdx] = padOp.getSourceType().getShape()[reInd[0]];
12521243
}
12531244
}
12541245

12551246
auto newCollapseType = RankedTensorType::get(
1256-
collapsedShape, padOp.getSource().getType().getElementType());
1247+
collapsedShape, padOp.getSourceType().getElementType());
12571248
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
12581249
loc, newCollapseType, padOp.getSource(), reassociations);
12591250

0 commit comments

Comments
 (0)