Skip to content

Commit fbd551c

Browse files
committed
fix: minor fixes
1 parent c9af1f8 commit fbd551c

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15008,8 +15008,9 @@ struct IfBinaryOpToSelectBinaryOp final
1500815008
Region &branchRegion, bool isTrueBranch,
1500915009
PatternRewriter &rewriter) const {
1501015010
auto binOp = branchVal.getDefiningOp<BinaryOpType>();
15011-
if (!binOp)
15011+
if (!binOp) {
1501215012
return nullptr;
15013+
}
1501315014

1501415015
Value base = nullptr;
1501515016
Value other = nullptr;
@@ -15026,26 +15027,31 @@ struct IfBinaryOpToSelectBinaryOp final
1502615027
matchLhs = false;
1502715028
}
1502815029

15029-
if (!base || !other)
15030+
if (!base || !other) {
1503015031
return nullptr;
15032+
}
1503115033

15032-
bool isCommutative = binOp->template hasTrait<mlir::OpTrait::IsCommutative>();
15034+
bool isCommutative =
15035+
binOp->template hasTrait<mlir::OpTrait::IsCommutative>();
1503315036

1503415037
// For non-commutative ops, we only support hoisting if the identity works.
1503515038
// We assume getIdentityValueForOp returns a right identity.
1503615039
// Thus, for non-commutative ops, the base MUST be the LHS.
15037-
if (!isCommutative && !matchLhs)
15040+
if (!isCommutative && !matchLhs) {
1503815041
return nullptr;
15042+
}
1503915043

1504015044
// Check if 'other' can be hoisted (not defined in the current branch)
15041-
if (&branchRegion == other.getParentRegion())
15045+
if (&branchRegion == other.getParentRegion()) {
1504215046
return nullptr;
15047+
}
1504315048

1504415049
auto elemType = cast<ShapedType>(branchVal.getType()).getElementType();
1504515050
Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
1504615051
rewriter, op.getLoc(), elemType);
15047-
if (!identity)
15052+
if (!identity) {
1504815053
return nullptr;
15054+
}
1504915055

1505015056
// Broadcast identity to match the type of 'other'
1505115057
auto otherType = cast<RankedTensorType>(other.getType());
@@ -15060,10 +15066,11 @@ struct IfBinaryOpToSelectBinaryOp final
1506015066

1506115067
Value selected = stablehlo::SelectOp::create(rewriter, op.getLoc(), pred,
1506215068
trueSelectVal, falseSelectVal);
15063-
if (matchLhs)
15069+
if (matchLhs) {
1506415070
return BinaryOpType::create(rewriter, op.getLoc(), base, selected);
15065-
else
15071+
} else {
1506615072
return BinaryOpType::create(rewriter, op.getLoc(), selected, base);
15073+
}
1506715074
}
1506815075
};
1506915076

test/lit_tests/if_binop_to_select.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ module {
3838
// Case with Mul and identity 1.0
3939
func.func @hoist_mul(%pred: tensor<i1>, %base: tensor<f32>, %other: tensor<f32>) -> tensor<f32> {
4040
%0 = "stablehlo.if"(%pred) ({
41-
%1 = stablehlo.mul %other, %base : tensor<f32>
41+
%1 = stablehlo.multiply %other, %base : tensor<f32>
4242
"stablehlo.return"(%1) : (tensor<f32>) -> ()
4343
}, {
4444
"stablehlo.return"(%base) : (tensor<f32>) -> ()
@@ -48,7 +48,7 @@ module {
4848
// CHECK-LABEL: func.func @hoist_mul
4949
// CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
5050
// CHECK: %[[SEL:.+]] = stablehlo.select %{{.+}}, %{{.+}}, %[[ONE]]
51-
// CHECK: %[[MUL:.+]] = stablehlo.mul %{{.+}}, %[[SEL]]
51+
// CHECK: %[[MUL:.+]] = stablehlo.multiply %{{.+}}, %[[SEL]]
5252

5353
// Case with Max and identity -Inf
5454
func.func @hoist_max(%pred: tensor<i1>, %base: tensor<f32>, %other: tensor<f32>) -> tensor<f32> {

0 commit comments

Comments
 (0)