Skip to content

Commit c9af1f8

Browse files
committed
feat: more ops
1 parent 0f673a9 commit c9af1f8

File tree

4 files changed

+281
-89
lines changed

4 files changed

+281
-89
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 79 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -14919,92 +14919,22 @@ struct IfBinaryOpToSelectBinaryOp final
1491914919
&op.getFalseBranch() != falseVal.getParentRegion();
1492014920

1492114921
// Case 1: true branch has binop, false branch returns base
14922-
// if(pred) { binop(base, other) } else { base }
14923-
// => binop(base, select(pred, other, identity))
14924-
if (auto binOp = trueVal.getDefiningOp<BinaryOpType>()) {
14925-
Value base = nullptr;
14926-
Value other = nullptr;
14927-
14928-
// Check if one operand of binop matches the false branch return
14929-
if (binOp.getLhs() == falseVal && falseDefinedOutside) {
14930-
base = falseVal;
14931-
other = binOp.getRhs();
14932-
} else if (binOp.getRhs() == falseVal && falseDefinedOutside) {
14933-
base = falseVal;
14934-
other = binOp.getLhs();
14935-
}
14936-
14937-
if (base && other) {
14938-
// Check if 'other' can be hoisted (defined outside true branch)
14939-
if (&op.getTrueBranch() != other.getParentRegion()) {
14940-
auto elemType =
14941-
cast<ShapedType>(op.getResult(i).getType()).getElementType();
14942-
Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
14943-
rewriter, op.getLoc(), elemType);
14944-
if (identity) {
14945-
// Broadcast identity to match the type of 'other'
14946-
auto otherType = cast<RankedTensorType>(other.getType());
14947-
if (identity.getType() != otherType) {
14948-
identity = stablehlo::BroadcastInDimOp::create(
14949-
rewriter, op.getLoc(), otherType, identity,
14950-
rewriter.getDenseI64ArrayAttr({}));
14951-
}
14952-
14953-
Value selected = stablehlo::SelectOp::create(
14954-
rewriter, op.getLoc(), pred, other, identity);
14955-
Value result =
14956-
BinaryOpType::create(rewriter, op.getLoc(), base, selected);
14957-
hoistedResults[i] = result;
14958-
anyHoisted = true;
14959-
continue;
14960-
}
14961-
}
14962-
}
14922+
if (Value hoisted = tryHoistBinaryOp(
14923+
op, pred, trueVal, falseVal, falseDefinedOutside,
14924+
op.getTrueBranch(), /*isTrueBranch=*/true, rewriter)) {
14925+
hoistedResults[i] = hoisted;
14926+
anyHoisted = true;
14927+
continue;
1496314928
}
1496414929

1496514930
// Case 2: false branch has binop, true branch returns base (symmetric
14966-
// case) if(pred) { base } else { binop(base, other) }
14967-
// => binop(base, select(pred, identity, other))
14968-
if (auto binOp = falseVal.getDefiningOp<BinaryOpType>()) {
14969-
Value base = nullptr;
14970-
Value other = nullptr;
14971-
14972-
// Check if one operand of binop matches the true branch return
14973-
if (binOp.getLhs() == trueVal && trueDefinedOutside) {
14974-
base = trueVal;
14975-
other = binOp.getRhs();
14976-
} else if (binOp.getRhs() == trueVal && trueDefinedOutside) {
14977-
base = trueVal;
14978-
other = binOp.getLhs();
14979-
}
14980-
14981-
if (base && other) {
14982-
// Check if 'other' can be hoisted (defined outside false branch)
14983-
if (&op.getFalseBranch() != other.getParentRegion()) {
14984-
auto elemType =
14985-
cast<ShapedType>(op.getResult(i).getType()).getElementType();
14986-
Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
14987-
rewriter, op.getLoc(), elemType);
14988-
if (identity) {
14989-
// Broadcast identity to match the type of 'other'
14990-
auto otherType = cast<RankedTensorType>(other.getType());
14991-
if (identity.getType() != otherType) {
14992-
identity = stablehlo::BroadcastInDimOp::create(
14993-
rewriter, op.getLoc(), otherType, identity,
14994-
rewriter.getDenseI64ArrayAttr({}));
14995-
}
14996-
14997-
// Note: pred is true => use identity, pred is false => use other
14998-
Value selected = stablehlo::SelectOp::create(
14999-
rewriter, op.getLoc(), pred, identity, other);
15000-
Value result =
15001-
BinaryOpType::create(rewriter, op.getLoc(), base, selected);
15002-
hoistedResults[i] = result;
15003-
anyHoisted = true;
15004-
continue;
15005-
}
15006-
}
15007-
}
14931+
// case)
14932+
if (Value hoisted = tryHoistBinaryOp(
14933+
op, pred, falseVal, trueVal, trueDefinedOutside,
14934+
op.getFalseBranch(), /*isTrueBranch=*/false, rewriter)) {
14935+
hoistedResults[i] = hoisted;
14936+
anyHoisted = true;
14937+
continue;
1500814938
}
1500914939
}
1501014940

@@ -15071,6 +15001,70 @@ struct IfBinaryOpToSelectBinaryOp final
1507115001
rewriter.replaceOp(op, finalResults);
1507215002
return success();
1507315003
}
15004+
15005+
private:
15006+
/// Tries to turn one result of a `stablehlo.if` into a branch-free expression.
///
/// Matches the shape
///   branch:       binop(base, other)   (or binop(other, base) if commutative)
///   other branch: base
/// and rewrites it to `binop(base, select(pred, other, identity))`, with the
/// select operands swapped when the binop lives in the false branch.
///
/// \param op        The `if` being rewritten (used for its location).
/// \param pred      The predicate of `op`.
/// \param branchVal Value returned by the branch that may contain the binop.
/// \param otherBranchVal Value returned by the opposite branch (the "base").
/// \param otherBranchDefinedOutside Whether `otherBranchVal` is defined
///        outside of its own branch region and can therefore be used by the
///        hoisted op.
/// \param branchRegion Region that `branchVal` is returned from.
/// \param isTrueBranch True when `branchRegion` is the true branch; decides
///        which side of the select receives `other`.
/// \param rewriter  Rewriter used to create the replacement ops.
/// \returns The hoisted value, or a null Value when the pattern does not
///          apply. No ops are created on the failure paths.
Value tryHoistBinaryOp(stablehlo::IfOp op, Value pred, Value branchVal,
                       Value otherBranchVal, bool otherBranchDefinedOutside,
                       Region &branchRegion, bool isTrueBranch,
                       PatternRewriter &rewriter) const {
  // The base is taken from the opposite branch, so it has to be visible
  // outside of that branch for the hoisted op to use it.
  if (!otherBranchDefinedOutside)
    return nullptr;

  auto binOp = branchVal.getDefiningOp<BinaryOpType>();
  if (!binOp)
    return nullptr;

  // One operand of the binop must be exactly the other branch's result.
  const bool matchLhs = binOp.getLhs() == otherBranchVal;
  if (!matchLhs && binOp.getRhs() != otherBranchVal)
    return nullptr;

  Value base = otherBranchVal;
  Value other = matchLhs ? binOp.getRhs() : binOp.getLhs();
  if (!base || !other)
    return nullptr;

  // For non-commutative ops, we only support hoisting if the identity works.
  // We assume getIdentityValueForOp returns a right identity.
  // Thus, for non-commutative ops, the base MUST be the LHS.
  const bool isCommutative =
      binOp->template hasTrait<mlir::OpTrait::IsCommutative>();
  if (!isCommutative && !matchLhs)
    return nullptr;

  // 'other' can only be hoisted if it is not defined in the current branch.
  if (&branchRegion == other.getParentRegion())
    return nullptr;

  // The identity is broadcast to the type of 'other' via broadcast_in_dim,
  // which needs a ranked, statically shaped result type. Check this before
  // creating any op so that bailing out leaves the IR untouched.
  auto otherType = dyn_cast<RankedTensorType>(other.getType());
  if (!otherType || !otherType.hasStaticShape())
    return nullptr;

  auto elemType = cast<ShapedType>(branchVal.getType()).getElementType();
  Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
      rewriter, op.getLoc(), elemType);
  if (!identity)
    return nullptr;

  // Broadcast identity to match the type of 'other'.
  // NOTE(review): the empty broadcast_dimensions presumes the identity is a
  // rank-0 tensor -- confirm against getIdentityValueForOp.
  if (identity.getType() != otherType) {
    identity = stablehlo::BroadcastInDimOp::create(
        rewriter, op.getLoc(), otherType, identity,
        rewriter.getDenseI64ArrayAttr({}));
  }

  // pred selects 'other' on the side where the binop originally executed and
  // the identity on the side that simply returned the base.
  Value onTrue = isTrueBranch ? other : identity;
  Value onFalse = isTrueBranch ? identity : other;
  Value selected =
      stablehlo::SelectOp::create(rewriter, op.getLoc(), pred, onTrue, onFalse);

  // Preserve the original operand order of the binop.
  Value newLhs = matchLhs ? base : selected;
  Value newRhs = matchLhs ? selected : base;
  return BinaryOpType::create(rewriter, op.getLoc(), newLhs, newRhs);
}
1507415068
};
1507515069

1507615070
// https://github.com/llvm/llvm-project/blob/74d8f3952c4acf6d57948983d7c5b0d0a7763c28/mlir/lib/Dialect/SCF/IR/SCF.cpp#L2313
@@ -34142,6 +34136,8 @@ struct EnzymeHLOOptPass
3414234136
IfBinaryOpToSelectBinaryOp<stablehlo::MulOp>,
3414334137
IfBinaryOpToSelectBinaryOp<stablehlo::MinOp>,
3414434138
IfBinaryOpToSelectBinaryOp<stablehlo::MaxOp>,
34139+
IfBinaryOpToSelectBinaryOp<stablehlo::SubtractOp>,
34140+
IfBinaryOpToSelectBinaryOp<stablehlo::DivOp>,
3414534141
IfPredPropagation,
3414634142
ImagOpCanon,
3414734143
MergeConsecutiveReshapes,

src/enzyme_ad/jax/Utils.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,6 +2072,19 @@ Value getIdentityValueForOp<stablehlo::MaxOp>(OpBuilder &builder, Location loc,
20722072
return nullptr;
20732073
}
20742074

2075+
template <>
2076+
Value getIdentityValueForOp<stablehlo::SubtractOp>(OpBuilder &builder,
2077+
Location loc,
2078+
Type elemType) {
2079+
return getIdentityValueForOp<stablehlo::AddOp>(builder, loc, elemType);
2080+
}
2081+
2082+
template <>
2083+
Value getIdentityValueForOp<stablehlo::DivOp>(OpBuilder &builder, Location loc,
2084+
Type elemType) {
2085+
return getIdentityValueForOp<stablehlo::MulOp>(builder, loc, elemType);
2086+
}
2087+
20752088
// Identity values for bitwise logical ops.
20762089
// OR/XOR: identity = 0
20772090
template <>
@@ -2110,9 +2123,11 @@ Value getIdentityValue(OpBuilder &builder, Location loc, Type elemType,
21102123
return TypeSwitch<Operation *, Value>(op)
21112124
.Case<stablehlo::AddOp, stablehlo::MulOp, stablehlo::MinOp,
21122125
stablehlo::MaxOp, stablehlo::OrOp, stablehlo::XorOp,
2113-
stablehlo::AndOp>([&](auto binOp) {
2114-
return getIdentityValueForOp<decltype(binOp)>(builder, loc, elemType);
2115-
})
2126+
stablehlo::AndOp, stablehlo::SubtractOp, stablehlo::DivOp>(
2127+
[&](auto binOp) {
2128+
return getIdentityValueForOp<decltype(binOp)>(builder, loc,
2129+
elemType);
2130+
})
21162131
.Default([&](Operation *op) -> Value { return nullptr; });
21172132
}
21182133

test/lit_tests/autobatching/addpositiveloop.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ module {
1414
%2 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
1515
stablehlo.return %2 : tensor<i1>
1616
} do {
17-
%2 = stablehlo.add %c_0, %iterArg {enzymexla.bounds = [[1, 10]]} : tensor<i64>
18-
%3 = stablehlo.convert %2 {enzymexla.bounds = [[1, 10]]} : (tensor<i64>) -> tensor<i32>
19-
%4 = stablehlo.subtract %3, %c {enzymexla.bounds = [[0, 9]]} : tensor<i32>
17+
%2 = stablehlo.add %c_0, %iterArg : tensor<i64>
18+
%3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
19+
%4 = stablehlo.subtract %3, %c : tensor<i32>
2020
%5 = stablehlo.dynamic_slice %arg0, %4, sizes = [1] : (tensor<10xf64>, tensor<i32>) -> tensor<1xf64>
2121
%6 = stablehlo.reshape %5 : (tensor<1xf64>) -> tensor<f64>
2222
%7 = stablehlo.dynamic_slice %0, %iterArg, sizes = [1] : (tensor<10xi1>, tensor<i64>) -> tensor<1xi1>

0 commit comments

Comments
 (0)