Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,18 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
return getLhs();

// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaximumFOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
Comment on lines +1119 to +1126
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if instead of covering both cases here and for other ops, we should instead use the fact that these are commutative and first move the other min/max operand to the RHS?

Copy link
Author

@ziliangzl ziliangzl Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Currently, the Commutative trait only provides a generic fold that moves constant operands to the end:

LogicalResult
OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// Nothing to fold if there are not at least 2 operands.
if (op->getNumOperands() < 2)
return failure();
// Move all constant operands to the end.
OpOperand *operandsBegin = op->getOpOperands().begin();
auto isNonConstant = [&](OpOperand &o) {
return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
};
auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
auto *newConstantIt = std::stable_partition(
firstConstantIt, op->getOpOperands().end(), isNonConstant);
// Return success if the op was modified.
return success(firstConstantIt != newConstantIt);
}

so it helps in canonicalize patterns like max(max(x, c0), c1) -> max(x, max(c0, c1)), where there are constants.

But for cases like fold patterns in this pr max(max(a, b), a) -> max(a, b) , no constants are involved, so the commutative fold doesn’t normalize operands. That’s why I explicitly check both lhs and rhs here.

An alternative would be to normalize operands inside the fold itself (e.g., always move the nested max to the RHS), but for now I kept both conditions for clarity.
cc @joker-eph

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative would be to normalize operands inside the fold itself (e.g., always move the nested max to the RHS)

Yes, I suggested this in the other revision: doing this in the trait would align all the commutative operation and then in-turn simplify all the folders (to avoid checking redundant forms)

if (auto max = getRhs().getDefiningOp<MaximumFOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getRhs();

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
Expand All @@ -1134,6 +1146,18 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();

// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaxNumFOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
if (auto max = getRhs().getDefiningOp<MaxNumFOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getRhs();

return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
}

Expand All @@ -1156,6 +1180,30 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
Comment on lines +1183 to +1190
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look fine to me

if (auto max = getRhs().getDefiningOp<MaxSIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getRhs();

// max(min(a, b), a) -> a
// max(min(b, a), a) -> a
if (auto min = getLhs().getDefiningOp<MinSIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getRhs();

// max(a, min(a, b)) -> a
// max(a, min(b, a)) -> a
Comment on lines +1195 to +1202
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if this is correct when one of the values is poison and gets discarded, but this should be fine since the optimization makes the result more defined: https://alive2.llvm.org/ce/z/cdoy8x

if (auto min = getRhs().getDefiningOp<MinSIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::smax(a, b);
Expand All @@ -1181,6 +1229,30 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaxUIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
if (auto max = getRhs().getDefiningOp<MaxUIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getRhs();

// max(min(a, b), a) -> a
// max(min(b, a), a) -> a
if (auto min = getLhs().getDefiningOp<MinUIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getRhs();

// max(a, min(a, b)) -> a
// max(a, min(b, a)) -> a
if (auto min = getRhs().getDefiningOp<MinUIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::umax(a, b);
Expand All @@ -1200,6 +1272,18 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
return getLhs();

// min(min(a, b), b) -> min(a, b)
// min(min(a, b), a) -> min(a, b)
if (auto min = getLhs().getDefiningOp<MinimumFOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getLhs();

// min(a, min(a, b)) -> min(a, b)
// min(b, min(a, b)) -> min(a, b)
if (auto min = getRhs().getDefiningOp<MinimumFOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getRhs();

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
Expand All @@ -1218,6 +1302,18 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
return getLhs();

// min(min(a, b), b) -> min(a, b)
// min(min(a, b), a) -> min(a, b)
if (auto min = getLhs().getDefiningOp<MinNumFOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getLhs();

// min(a, min(a, b)) -> min(a, b)
// min(b, min(a, b)) -> min(a, b)
if (auto min = getRhs().getDefiningOp<MinNumFOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getRhs();

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
Expand All @@ -1242,6 +1338,30 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// min(min(a, b), b) -> min(a, b)
// min(min(a, b), a) -> min(a, b)
if (auto min = getLhs().getDefiningOp<MinSIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getLhs();

// min(a, min(a, b)) -> min(a, b)
// min(b, min(a, b)) -> min(a, b)
if (auto min = getRhs().getDefiningOp<MinSIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getRhs();

// min(max(a, b), a) -> a
// min(max(b, a), a) -> a
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getRhs();

// min(a, max(a, b)) -> a
// min(a, max(b, a)) -> a
if (auto max = getRhs().getDefiningOp<MaxSIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::smin(a, b);
Expand All @@ -1267,6 +1387,30 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// min(min(a, b), b) -> min(a, b)
// min(min(a, b), a) -> min(a, b)
if (auto min = getLhs().getDefiningOp<MinUIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getLhs();

// min(a, min(a, b)) -> min(a, b)
// min(b, min(a, b)) -> min(a, b)
if (auto min = getRhs().getDefiningOp<MinUIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getRhs();

// min(max(a, b), a) -> a
// min(max(b, a), a) -> a
if (auto max = getLhs().getDefiningOp<MaxUIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getRhs();

// min(a, max(a, b)) -> a
// min(a, max(b, a)) -> a
if (auto max = getRhs().getDefiningOp<MaxUIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::umin(a, b);
Expand Down
Loading