@@ -14919,92 +14919,22 @@ struct IfBinaryOpToSelectBinaryOp final
1491914919 &op.getFalseBranch() != falseVal.getParentRegion();
1492014920
1492114921 // Case 1: true branch has binop, false branch returns base
14922- // if(pred) { binop(base, other) } else { base }
14923- // => binop(base, select(pred, other, identity))
14924- if (auto binOp = trueVal.getDefiningOp<BinaryOpType>()) {
14925- Value base = nullptr;
14926- Value other = nullptr;
14927-
14928- // Check if one operand of binop matches the false branch return
14929- if (binOp.getLhs() == falseVal && falseDefinedOutside) {
14930- base = falseVal;
14931- other = binOp.getRhs();
14932- } else if (binOp.getRhs() == falseVal && falseDefinedOutside) {
14933- base = falseVal;
14934- other = binOp.getLhs();
14935- }
14936-
14937- if (base && other) {
14938- // Check if 'other' can be hoisted (defined outside true branch)
14939- if (&op.getTrueBranch() != other.getParentRegion()) {
14940- auto elemType =
14941- cast<ShapedType>(op.getResult(i).getType()).getElementType();
14942- Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
14943- rewriter, op.getLoc(), elemType);
14944- if (identity) {
14945- // Broadcast identity to match the type of 'other'
14946- auto otherType = cast<RankedTensorType>(other.getType());
14947- if (identity.getType() != otherType) {
14948- identity = stablehlo::BroadcastInDimOp::create(
14949- rewriter, op.getLoc(), otherType, identity,
14950- rewriter.getDenseI64ArrayAttr({}));
14951- }
14952-
14953- Value selected = stablehlo::SelectOp::create(
14954- rewriter, op.getLoc(), pred, other, identity);
14955- Value result =
14956- BinaryOpType::create(rewriter, op.getLoc(), base, selected);
14957- hoistedResults[i] = result;
14958- anyHoisted = true;
14959- continue;
14960- }
14961- }
14962- }
14922+ if (Value hoisted = tryHoistBinaryOp(
14923+ op, pred, trueVal, falseVal, falseDefinedOutside,
14924+ op.getTrueBranch(), /*isTrueBranch=*/true, rewriter)) {
14925+ hoistedResults[i] = hoisted;
14926+ anyHoisted = true;
14927+ continue;
1496314928 }
1496414929
1496514930 // Case 2: false branch has binop, true branch returns base (symmetric
14966- // case) if(pred) { base } else { binop(base, other) }
14967- // => binop(base, select(pred, identity, other))
14968- if (auto binOp = falseVal.getDefiningOp<BinaryOpType>()) {
14969- Value base = nullptr;
14970- Value other = nullptr;
14971-
14972- // Check if one operand of binop matches the true branch return
14973- if (binOp.getLhs() == trueVal && trueDefinedOutside) {
14974- base = trueVal;
14975- other = binOp.getRhs();
14976- } else if (binOp.getRhs() == trueVal && trueDefinedOutside) {
14977- base = trueVal;
14978- other = binOp.getLhs();
14979- }
14980-
14981- if (base && other) {
14982- // Check if 'other' can be hoisted (defined outside false branch)
14983- if (&op.getFalseBranch() != other.getParentRegion()) {
14984- auto elemType =
14985- cast<ShapedType>(op.getResult(i).getType()).getElementType();
14986- Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
14987- rewriter, op.getLoc(), elemType);
14988- if (identity) {
14989- // Broadcast identity to match the type of 'other'
14990- auto otherType = cast<RankedTensorType>(other.getType());
14991- if (identity.getType() != otherType) {
14992- identity = stablehlo::BroadcastInDimOp::create(
14993- rewriter, op.getLoc(), otherType, identity,
14994- rewriter.getDenseI64ArrayAttr({}));
14995- }
14996-
14997- // Note: pred is true => use identity, pred is false => use other
14998- Value selected = stablehlo::SelectOp::create(
14999- rewriter, op.getLoc(), pred, identity, other);
15000- Value result =
15001- BinaryOpType::create(rewriter, op.getLoc(), base, selected);
15002- hoistedResults[i] = result;
15003- anyHoisted = true;
15004- continue;
15005- }
15006- }
15007- }
14931+ // case)
14932+ if (Value hoisted = tryHoistBinaryOp(
14933+ op, pred, falseVal, trueVal, trueDefinedOutside,
14934+ op.getFalseBranch(), /*isTrueBranch=*/false, rewriter)) {
14935+ hoistedResults[i] = hoisted;
14936+ anyHoisted = true;
14937+ continue;
1500814938 }
1500914939 }
1501014940
@@ -15071,6 +15001,70 @@ struct IfBinaryOpToSelectBinaryOp final
1507115001 rewriter.replaceOp(op, finalResults);
1507215002 return success();
1507315003 }
15004+
15005+ private:
15006+ Value tryHoistBinaryOp(stablehlo::IfOp op, Value pred, Value branchVal,
15007+ Value otherBranchVal, bool otherBranchDefinedOutside,
15008+ Region &branchRegion, bool isTrueBranch,
15009+ PatternRewriter &rewriter) const {
15010+ auto binOp = branchVal.getDefiningOp<BinaryOpType>();
15011+ if (!binOp)
15012+ return nullptr;
15013+
15014+ Value base = nullptr;
15015+ Value other = nullptr;
15016+ bool matchLhs = false;
15017+
15018+ // Check if one operand of binop matches the other branch return
15019+ if (binOp.getLhs() == otherBranchVal && otherBranchDefinedOutside) {
15020+ base = otherBranchVal;
15021+ other = binOp.getRhs();
15022+ matchLhs = true;
15023+ } else if (binOp.getRhs() == otherBranchVal && otherBranchDefinedOutside) {
15024+ base = otherBranchVal;
15025+ other = binOp.getLhs();
15026+ matchLhs = false;
15027+ }
15028+
15029+ if (!base || !other)
15030+ return nullptr;
15031+
15032+ bool isCommutative = binOp->template hasTrait<mlir::OpTrait::IsCommutative>();
15033+
15034+ // For non-commutative ops, we only support hoisting if the identity works.
15035+ // We assume getIdentityValueForOp returns a right identity.
15036+ // Thus, for non-commutative ops, the base MUST be the LHS.
15037+ if (!isCommutative && !matchLhs)
15038+ return nullptr;
15039+
15040+ // Check if 'other' can be hoisted (not defined in the current branch)
15041+ if (&branchRegion == other.getParentRegion())
15042+ return nullptr;
15043+
15044+ auto elemType = cast<ShapedType>(branchVal.getType()).getElementType();
15045+ Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
15046+ rewriter, op.getLoc(), elemType);
15047+ if (!identity)
15048+ return nullptr;
15049+
15050+ // Broadcast identity to match the type of 'other'
15051+ auto otherType = cast<RankedTensorType>(other.getType());
15052+ if (identity.getType() != otherType) {
15053+ identity = stablehlo::BroadcastInDimOp::create(
15054+ rewriter, op.getLoc(), otherType, identity,
15055+ rewriter.getDenseI64ArrayAttr({}));
15056+ }
15057+
15058+ Value trueSelectVal = isTrueBranch ? other : identity;
15059+ Value falseSelectVal = isTrueBranch ? identity : other;
15060+
15061+ Value selected = stablehlo::SelectOp::create(rewriter, op.getLoc(), pred,
15062+ trueSelectVal, falseSelectVal);
15063+ if (matchLhs)
15064+ return BinaryOpType::create(rewriter, op.getLoc(), base, selected);
15065+ else
15066+ return BinaryOpType::create(rewriter, op.getLoc(), selected, base);
15067+ }
1507415068};
1507515069
1507615070// https://github.com/llvm/llvm-project/blob/74d8f3952c4acf6d57948983d7c5b0d0a7763c28/mlir/lib/Dialect/SCF/IR/SCF.cpp#L2313
@@ -34142,6 +34136,8 @@ struct EnzymeHLOOptPass
3414234136 IfBinaryOpToSelectBinaryOp<stablehlo::MulOp>,
3414334137 IfBinaryOpToSelectBinaryOp<stablehlo::MinOp>,
3414434138 IfBinaryOpToSelectBinaryOp<stablehlo::MaxOp>,
34139+ IfBinaryOpToSelectBinaryOp<stablehlo::SubtractOp>,
34140+ IfBinaryOpToSelectBinaryOp<stablehlo::DivOp>,
3414534141 IfPredPropagation,
3414634142 ImagOpCanon,
3414734143 MergeConsecutiveReshapes,
0 commit comments