@@ -4842,11 +4842,93 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
48424842 return SDValue ();
48434843}
48444844
4845+ // Detect when CMP and SELECT use the same constant and fold them to avoid
4846+ // loading the constant twice. Specifically handles patterns like:
4847+ // %cmp = icmp eq i32 %val, 4242
4848+ // %sel = select i1 %cmp, i32 4242, i32 %other
4849+ // It can be optimized to reuse %val instead of 4242 in select.
4850+ static SDValue
4851+ foldCmpSelectWithSharedConstant (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
4852+ const AMDGPUSubtarget *ST) {
4853+ SDValue Cond = N->getOperand (0 );
4854+ SDValue TrueVal = N->getOperand (1 );
4855+ SDValue FalseVal = N->getOperand (2 );
4856+
4857+ // Check if condition is a comparison.
4858+ if (Cond.getOpcode () != ISD::SETCC)
4859+ return SDValue ();
4860+
4861+ SDValue LHS = Cond.getOperand (0 );
4862+ SDValue RHS = Cond.getOperand (1 );
4863+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand (2 ))->get ();
4864+
4865+ bool isFloatingPoint = LHS.getValueType ().isFloatingPoint ();
4866+ bool isInteger = LHS.getValueType ().isInteger ();
4867+
4868+ // Handle simple floating-point and integer types only.
4869+ if (!isFloatingPoint && !isInteger)
4870+ return SDValue ();
4871+
4872+ bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
4873+ bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
4874+ if (!isEquality && !isNonEquality)
4875+ return SDValue ();
4876+
4877+ SDValue ArgVal, ConstVal;
4878+ if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
4879+ (isInteger && isa<ConstantSDNode>(RHS))) {
4880+ ConstVal = RHS;
4881+ ArgVal = LHS;
4882+ } else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
4883+ (isInteger && isa<ConstantSDNode>(LHS))) {
4884+ ConstVal = LHS;
4885+ ArgVal = RHS;
4886+ } else {
4887+ return SDValue ();
4888+ }
4889+
4890+ // Check if constant should not be optimized - early return if not.
4891+ if (isFloatingPoint) {
4892+ const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF ();
4893+ const GCNSubtarget *GCNST = static_cast <const GCNSubtarget *>(ST);
4894+
4895+ // Only optimize normal floating-point values, skip optimization for
4896+ // inlinable floating-point constants.
4897+ if (!Val.isNormal () || GCNST->getInstrInfo ()->isInlineConstant (Val))
4898+ return SDValue ();
4899+ } else {
4900+ int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue ();
4901+
4902+ // Skip optimization for inlinable integer immediates.
4903+ // Inlinable immediates include: -16 to 64 (inclusive).
4904+ if (IntVal >= -16 && IntVal <= 64 )
4905+ return SDValue ();
4906+ }
4907+
4908+ // For equality and non-equality comparisons, patterns:
4909+ // select (setcc x, const), const, y -> select (setcc x, const), x, y
4910+ // select (setccinv x, const), y, const -> select (setccinv x, const), y, x
4911+ if (!(isEquality && TrueVal == ConstVal) &&
4912+ !(isNonEquality && FalseVal == ConstVal))
4913+ return SDValue ();
4914+
4915+ SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
4916+ SDValue SelectRHS =
4917+ (isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
4918+ return DCI.DAG .getNode (ISD::SELECT, SDLoc (N), N->getValueType (0 ), Cond,
4919+ SelectLHS, SelectRHS);
4920+ }
4921+
48454922SDValue AMDGPUTargetLowering::performSelectCombine (SDNode *N,
48464923 DAGCombinerInfo &DCI) const {
48474924 if (SDValue Folded = foldFreeOpFromSelect (DCI, SDValue (N, 0 )))
48484925 return Folded;
48494926
4927+ // Try to fold CMP + SELECT patterns with shared constants (both FP and
4928+ // integer).
4929+ if (SDValue Folded = foldCmpSelectWithSharedConstant (N, DCI, Subtarget))
4930+ return Folded;
4931+
48504932 SDValue Cond = N->getOperand (0 );
48514933 if (Cond.getOpcode () != ISD::SETCC)
48524934 return SDValue ();
0 commit comments