@@ -12500,63 +12500,148 @@ struct CompareOpCanon final
1250012500
1250112501 // Bail out on non-integer comparison.
1250212502 // TODO: Support more comparison types.
12503- using stablehlo::ComparisonType;
12504- std::optional<ComparisonType> compType = op.getCompareType();
12505- if (!compType ||
12506- !llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED},
12507- *compType)) {
12508- return failure();
12509- }
12510-
1251112503 using stablehlo::ComparisonDirection;
12504+ using stablehlo::ComparisonType;
1251212505 ComparisonDirection direction = op.getComparisonDirection();
1251312506 Value lhs = op.getLhs();
1251412507 Value rhs = op.getRhs();
12508+ std::optional<ComparisonType> compType = op.getCompareType();
1251512509
12516- if (lhs == rhs) {
12517- switch (direction) {
12518- case ComparisonDirection::EQ:
12519- case ComparisonDirection::GE:
12520- case ComparisonDirection::LE: {
12521- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12522- op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
12523- return success();
12510+ TypedAttr lhsAttr;
12511+ matchPattern(lhs, m_Constant(&lhsAttr));
12512+
12513+ TypedAttr rhsAttr;
12514+ matchPattern(rhs, m_Constant(&rhsAttr));
12515+
12516+ if (compType &&
12517+ llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED},
12518+ *compType)) {
12519+
12520+ if (lhs == rhs) {
12521+ switch (direction) {
12522+ case ComparisonDirection::EQ:
12523+ case ComparisonDirection::GE:
12524+ case ComparisonDirection::LE: {
12525+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12526+ op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
12527+ return success();
12528+ }
12529+ case ComparisonDirection::GT:
12530+ case ComparisonDirection::LT:
12531+ case ComparisonDirection::NE: {
12532+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12533+ op, rewriter.getZeroAttr(type));
12534+ return success();
12535+ }
12536+ }
12537+ llvm_unreachable("Unhandled case");
1252412538 }
12525- case ComparisonDirection::GT:
12526- case ComparisonDirection::LT:
12527- case ComparisonDirection::NE: {
12528- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12529- op, rewriter.getZeroAttr(type));
12539+
12540+ // The canonical form has the constant operand as the RHS.
12541+ if (lhsAttr && !rhsAttr) {
12542+ rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] {
12543+ op.setComparisonDirection(invertDirection(direction));
12544+ op->setOperands(ValueRange{rhs, lhs});
12545+ });
1253012546 return success();
1253112547 }
12548+
12549+ if (Attribute res;
12550+ lhsAttr && rhsAttr &&
12551+ (res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(
12552+ ArrayRef<Attribute>({lhsAttr, rhsAttr}), op.getType(),
12553+ [direction, kind = *compType](const APInt &a, const APInt &b) {
12554+ return calculateComp(kind, direction, a, b);
12555+ }))) {
12556+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, res);
12557+ return success();
1253212558 }
12533- llvm_unreachable("Unhandled case");
1253412559 }
1253512560
12536- TypedAttr lhsAttr;
12537- matchPattern(lhs, m_Constant(&lhsAttr));
12561+ auto simplifyNonNegative = [&](Attribute attr,
12562+ ComparisonDirection direction,
12563+ Value otherOperand) -> LogicalResult {
12564+ APInt val;
12565+ APFloat valFloat(APFloat::IEEEdouble());
12566+ bool isSplat = false, isFloat = false;
12567+
12568+ if (auto dense = dyn_cast<DenseElementsAttr>(attr)) {
12569+ if (dense.isSplat()) {
12570+ auto splatAttr = dense.getSplatValue<Attribute>();
12571+ if (auto intAttr = dyn_cast<IntegerAttr>(splatAttr)) {
12572+ val = intAttr.getValue();
12573+ isSplat = true;
12574+ } else if (auto floatAttr = dyn_cast<FloatAttr>(splatAttr)) {
12575+ valFloat = floatAttr.getValue();
12576+ isFloat = true;
12577+ isSplat = true;
12578+ }
12579+ }
12580+ }
1253812581
12539- TypedAttr rhsAttr;
12540- matchPattern(rhs, m_Constant(&rhsAttr));
12582+ if (isSplat) {
12583+ bool alwaysTrue = false;
12584+ bool alwaysFalse = false;
1254112585
12542- // The canonical form has the constant operand as the RHS.
12543- if (lhsAttr && !rhsAttr) {
12544- rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] {
12545- op.setComparisonDirection(invertDirection(direction));
12546- op->setOperands(ValueRange{rhs, lhs});
12547- });
12548- return success();
12549- }
12586+ // Check if the constant is negative or zero
12587+ bool isNegative = isFloat ? valFloat.isNegative() : val.isNegative();
12588+ bool isZeroVal = isFloat ? valFloat.isZero() : val.isZero();
1255012589
12551- if (Attribute res;
12552- lhsAttr && rhsAttr &&
12553- (res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(
12554- ArrayRef<Attribute>({lhsAttr, rhsAttr}), op.getType(),
12555- [direction, kind = *compType](const APInt &a, const APInt &b) {
12556- return calculateComp(kind, direction, a, b);
12557- }))) {
12558- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, res);
12559- return success();
12590+ if ((compType && *compType == ComparisonType::SIGNED) ||
12591+ (isFloat && (!compType || *compType == ComparisonType::FLOAT) &&
12592+ guaranteedNoNanResult(otherOperand, rewriter))) {
12593+ if (isNegative) {
12594+ switch (direction) {
12595+ case ComparisonDirection::EQ:
12596+ case ComparisonDirection::LE:
12597+ case ComparisonDirection::LT:
12598+ alwaysFalse = true;
12599+ break;
12600+ case ComparisonDirection::NE:
12601+ case ComparisonDirection::GE:
12602+ case ComparisonDirection::GT:
12603+ alwaysTrue = true;
12604+ break;
12605+ }
12606+ } else if (isZeroVal) {
12607+ if (direction == ComparisonDirection::LT)
12608+ alwaysFalse = true;
12609+ if (direction == ComparisonDirection::GE)
12610+ alwaysTrue = true;
12611+ }
12612+ } else if (compType && *compType == ComparisonType::UNSIGNED) {
12613+ if (isZeroVal) {
12614+ if (direction == ComparisonDirection::LT)
12615+ alwaysFalse = true;
12616+ if (direction == ComparisonDirection::GE)
12617+ alwaysTrue = true;
12618+ }
12619+ }
12620+
12621+ if (alwaysTrue) {
12622+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12623+ op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
12624+ return success();
12625+ }
12626+ if (alwaysFalse) {
12627+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
12628+ op, rewriter.getZeroAttr(type));
12629+ return success();
12630+ }
12631+ }
12632+ return failure();
12633+ };
12634+
12635+ if (rhsAttr && guaranteedNonNegativeResult(lhs, rewriter)) {
12636+ if (succeeded(simplifyNonNegative(rhsAttr, direction, lhs))) {
12637+ return success();
12638+ }
12639+ }
12640+ if (lhsAttr && guaranteedNonNegativeResult(rhs, rewriter)) {
12641+ if (succeeded(
12642+ simplifyNonNegative(lhsAttr, invertDirection(direction), rhs))) {
12643+ return success();
12644+ }
1256012645 }
1256112646
1256212647 return failure();
0 commit comments