@@ -15008,8 +15008,9 @@ struct IfBinaryOpToSelectBinaryOp final
1500815008 Region &branchRegion, bool isTrueBranch,
1500915009 PatternRewriter &rewriter) const {
1501015010 auto binOp = branchVal.getDefiningOp<BinaryOpType>();
15011- if (!binOp)
15011+ if (!binOp) {
1501215012 return nullptr;
15013+ }
1501315014
1501415015 Value base = nullptr;
1501515016 Value other = nullptr;
@@ -15026,26 +15027,31 @@ struct IfBinaryOpToSelectBinaryOp final
1502615027 matchLhs = false;
1502715028 }
1502815029
15029- if (!base || !other)
15030+ if (!base || !other) {
1503015031 return nullptr;
15032+ }
1503115033
15032- bool isCommutative = binOp->template hasTrait<mlir::OpTrait::IsCommutative>();
15034+ bool isCommutative =
15035+ binOp->template hasTrait<mlir::OpTrait::IsCommutative>();
1503315036
1503415037 // For non-commutative ops, we only support hoisting if the identity works.
1503515038 // We assume getIdentityValueForOp returns a right identity.
1503615039 // Thus, for non-commutative ops, the base MUST be the LHS.
15037- if (!isCommutative && !matchLhs)
15040+ if (!isCommutative && !matchLhs) {
1503815041 return nullptr;
15042+ }
1503915043
1504015044 // Check if 'other' can be hoisted (not defined in the current branch)
15041- if (&branchRegion == other.getParentRegion())
15045+ if (&branchRegion == other.getParentRegion()) {
1504215046 return nullptr;
15047+ }
1504315048
1504415049 auto elemType = cast<ShapedType>(branchVal.getType()).getElementType();
1504515050 Value identity = stablehlo::getIdentityValueForOp<BinaryOpType>(
1504615051 rewriter, op.getLoc(), elemType);
15047- if (!identity)
15052+ if (!identity) {
1504815053 return nullptr;
15054+ }
1504915055
1505015056 // Broadcast identity to match the type of 'other'
1505115057 auto otherType = cast<RankedTensorType>(other.getType());
@@ -15060,10 +15066,11 @@ struct IfBinaryOpToSelectBinaryOp final
1506015066
1506115067 Value selected = stablehlo::SelectOp::create(rewriter, op.getLoc(), pred,
1506215068 trueSelectVal, falseSelectVal);
15063- if (matchLhs)
15069+ if (matchLhs) {
1506415070 return BinaryOpType::create(rewriter, op.getLoc(), base, selected);
15065- else
15071+ } else {
1506615072 return BinaryOpType::create(rewriter, op.getLoc(), selected, base);
15073+ }
1506715074 }
1506815075};
1506915076
0 commit comments