Skip to content

Commit e29d921

Browse files
committed
feat: compare op simplification for non-negative operands
1 parent 7c8b87a commit e29d921

File tree

6 files changed

+378
-52
lines changed

6 files changed

+378
-52
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 128 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12500,63 +12500,148 @@ struct CompareOpCanon final
1250012500

1250112501
// Bail out on non-integer comparison.
1250212502
// TODO: Support more comparison types.
12503-
using stablehlo::ComparisonType;
12504-
std::optional<ComparisonType> compType = op.getCompareType();
12505-
if (!compType ||
12506-
!llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED},
12507-
*compType)) {
12508-
return failure();
12509-
}
12510-
1251112503
using stablehlo::ComparisonDirection;
12504+
using stablehlo::ComparisonType;
1251212505
ComparisonDirection direction = op.getComparisonDirection();
1251312506
Value lhs = op.getLhs();
1251412507
Value rhs = op.getRhs();
12508+
std::optional<ComparisonType> compType = op.getCompareType();
1251512509

12516-
if (lhs == rhs) {
12517-
switch (direction) {
12518-
case ComparisonDirection::EQ:
12519-
case ComparisonDirection::GE:
12520-
case ComparisonDirection::LE: {
12521-
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12522-
op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
12523-
return success();
12510+
TypedAttr lhsAttr;
12511+
matchPattern(lhs, m_Constant(&lhsAttr));
12512+
12513+
TypedAttr rhsAttr;
12514+
matchPattern(rhs, m_Constant(&rhsAttr));
12515+
12516+
if (compType &&
12517+
llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED},
12518+
*compType)) {
12519+
12520+
if (lhs == rhs) {
12521+
switch (direction) {
12522+
case ComparisonDirection::EQ:
12523+
case ComparisonDirection::GE:
12524+
case ComparisonDirection::LE: {
12525+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12526+
op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
12527+
return success();
12528+
}
12529+
case ComparisonDirection::GT:
12530+
case ComparisonDirection::LT:
12531+
case ComparisonDirection::NE: {
12532+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12533+
op, rewriter.getZeroAttr(type));
12534+
return success();
12535+
}
12536+
}
12537+
llvm_unreachable("Unhandled case");
1252412538
}
12525-
case ComparisonDirection::GT:
12526-
case ComparisonDirection::LT:
12527-
case ComparisonDirection::NE: {
12528-
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12529-
op, rewriter.getZeroAttr(type));
12539+
12540+
// The canonical form has the constant operand as the RHS.
12541+
if (lhsAttr && !rhsAttr) {
12542+
rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] {
12543+
op.setComparisonDirection(invertDirection(direction));
12544+
op->setOperands(ValueRange{rhs, lhs});
12545+
});
1253012546
return success();
1253112547
}
12548+
12549+
if (Attribute res;
12550+
lhsAttr && rhsAttr &&
12551+
(res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(
12552+
ArrayRef<Attribute>({lhsAttr, rhsAttr}), op.getType(),
12553+
[direction, kind = *compType](const APInt &a, const APInt &b) {
12554+
return calculateComp(kind, direction, a, b);
12555+
}))) {
12556+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, res);
12557+
return success();
1253212558
}
12533-
llvm_unreachable("Unhandled case");
1253412559
}
1253512560

12536-
TypedAttr lhsAttr;
12537-
matchPattern(lhs, m_Constant(&lhsAttr));
12561+
auto simplifyNonNegative = [&](Attribute attr,
12562+
ComparisonDirection direction,
12563+
Value otherOperand) -> LogicalResult {
12564+
APInt val;
12565+
APFloat valFloat(APFloat::IEEEdouble());
12566+
bool isSplat = false, isFloat = false;
12567+
12568+
if (auto dense = dyn_cast<DenseElementsAttr>(attr)) {
12569+
if (dense.isSplat()) {
12570+
auto splatAttr = dense.getSplatValue<Attribute>();
12571+
if (auto intAttr = dyn_cast<IntegerAttr>(splatAttr)) {
12572+
val = intAttr.getValue();
12573+
isSplat = true;
12574+
} else if (auto floatAttr = dyn_cast<FloatAttr>(splatAttr)) {
12575+
valFloat = floatAttr.getValue();
12576+
isFloat = true;
12577+
isSplat = true;
12578+
}
12579+
}
12580+
}
1253812581

12539-
TypedAttr rhsAttr;
12540-
matchPattern(rhs, m_Constant(&rhsAttr));
12582+
if (isSplat) {
12583+
bool alwaysTrue = false;
12584+
bool alwaysFalse = false;
1254112585

12542-
// The canonical form has the constant operand as the RHS.
12543-
if (lhsAttr && !rhsAttr) {
12544-
rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] {
12545-
op.setComparisonDirection(invertDirection(direction));
12546-
op->setOperands(ValueRange{rhs, lhs});
12547-
});
12548-
return success();
12549-
}
12586+
// Check if the constant is negative or zero
12587+
bool isNegative = isFloat ? valFloat.isNegative() : val.isNegative();
12588+
bool isZeroVal = isFloat ? valFloat.isZero() : val.isZero();
1255012589

12551-
if (Attribute res;
12552-
lhsAttr && rhsAttr &&
12553-
(res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(
12554-
ArrayRef<Attribute>({lhsAttr, rhsAttr}), op.getType(),
12555-
[direction, kind = *compType](const APInt &a, const APInt &b) {
12556-
return calculateComp(kind, direction, a, b);
12557-
}))) {
12558-
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, res);
12559-
return success();
12590+
if ((compType && *compType == ComparisonType::SIGNED) ||
12591+
(isFloat && (!compType || *compType == ComparisonType::FLOAT) &&
12592+
guaranteedNoNanResult(otherOperand, rewriter))) {
12593+
if (isNegative) {
12594+
switch (direction) {
12595+
case ComparisonDirection::EQ:
12596+
case ComparisonDirection::LE:
12597+
case ComparisonDirection::LT:
12598+
alwaysFalse = true;
12599+
break;
12600+
case ComparisonDirection::NE:
12601+
case ComparisonDirection::GE:
12602+
case ComparisonDirection::GT:
12603+
alwaysTrue = true;
12604+
break;
12605+
}
12606+
} else if (isZeroVal) {
12607+
if (direction == ComparisonDirection::LT)
12608+
alwaysFalse = true;
12609+
if (direction == ComparisonDirection::GE)
12610+
alwaysTrue = true;
12611+
}
12612+
} else if (compType && *compType == ComparisonType::UNSIGNED) {
12613+
if (isZeroVal) {
12614+
if (direction == ComparisonDirection::LT)
12615+
alwaysFalse = true;
12616+
if (direction == ComparisonDirection::GE)
12617+
alwaysTrue = true;
12618+
}
12619+
}
12620+
12621+
if (alwaysTrue) {
12622+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12623+
op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
12624+
return success();
12625+
}
12626+
if (alwaysFalse) {
12627+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12628+
op, rewriter.getZeroAttr(type));
12629+
return success();
12630+
}
12631+
}
12632+
return failure();
12633+
};
12634+
12635+
if (rhsAttr && guaranteedNonNegativeResult(lhs, rewriter)) {
12636+
if (succeeded(simplifyNonNegative(rhsAttr, direction, lhs))) {
12637+
return success();
12638+
}
12639+
}
12640+
if (lhsAttr && guaranteedNonNegativeResult(rhs, rewriter)) {
12641+
if (succeeded(
12642+
simplifyNonNegative(lhsAttr, invertDirection(direction), rhs))) {
12643+
return success();
12644+
}
1256012645
}
1256112646

1256212647
return failure();

0 commit comments

Comments
 (0)