@@ -4843,11 +4843,94 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
48434843 return SDValue ();
48444844}
48454845
4846+ // Detect when CMP and SELECT use the same constant and fold them to avoid
4847+ // loading the constant twice. Specifically handles patterns like:
4848+ // %cmp = icmp eq i32 %val, 4242
4849+ // %sel = select i1 %cmp, i32 4242, i32 %other
4850+ // It can be optimized to reuse %val instead of 4242 in select.
4851+ static SDValue
4852+ foldCmpSelectWithSharedConstant (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
4853+ const AMDGPUSubtarget *ST) {
4854+ SDValue Cond = N->getOperand (0 );
4855+ SDValue TrueVal = N->getOperand (1 );
4856+ SDValue FalseVal = N->getOperand (2 );
4857+
4858+ // Check if condition is a comparison.
4859+ if (Cond.getOpcode () != ISD::SETCC)
4860+ return SDValue ();
4861+
4862+ SDValue LHS = Cond.getOperand (0 );
4863+ SDValue RHS = Cond.getOperand (1 );
4864+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand (2 ))->get ();
4865+
4866+ bool isFloatingPoint = LHS.getValueType ().isFloatingPoint ();
4867+ bool isInteger = LHS.getValueType ().isInteger ();
4868+
4869+ // Handle simple floating-point and integer types only.
4870+ if (!isFloatingPoint && !isInteger)
4871+ return SDValue ();
4872+
4873+ bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
4874+ bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
4875+ if (!isEquality && !isNonEquality)
4876+ return SDValue ();
4877+
4878+ SDValue ArgVal, ConstVal;
4879+ if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
4880+ (isInteger && isa<ConstantSDNode>(RHS))) {
4881+ ConstVal = RHS;
4882+ ArgVal = LHS;
4883+ } else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
4884+ (isInteger && isa<ConstantSDNode>(LHS))) {
4885+ ConstVal = LHS;
4886+ ArgVal = RHS;
4887+ } else {
4888+ return SDValue ();
4889+ }
4890+
4891+ // Check if constant should not be optimized - early return if not.
4892+ if (isFloatingPoint) {
4893+ const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF ();
4894+ const GCNSubtarget *GCNST = static_cast <const GCNSubtarget *>(ST);
4895+
4896+ // Only optimize normal floating-point values (finite, non-zero, and
4897+ // non-subnormal as per IEEE 754), skip optimization for inlinable
4898+ // floating-point constants.
4899+ if (!Val.isNormal () || GCNST->getInstrInfo ()->isInlineConstant (Val))
4900+ return SDValue ();
4901+ } else {
4902+ int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue ();
4903+
4904+ // Skip optimization for inlinable integer immediates.
4905+ // Inlinable immediates include: -16 to 64 (inclusive).
4906+ if (IntVal >= -16 && IntVal <= 64 )
4907+ return SDValue ();
4908+ }
4909+
4910+ // For equality and non-equality comparisons, patterns:
4911+ // select (setcc x, const), const, y -> select (setcc x, const), x, y
4912+ // select (setccinv x, const), y, const -> select (setccinv x, const), y, x
4913+ if (!(isEquality && TrueVal == ConstVal) &&
4914+ !(isNonEquality && FalseVal == ConstVal))
4915+ return SDValue ();
4916+
4917+ SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
4918+ SDValue SelectRHS =
4919+ (isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
4920+ return DCI.DAG .getNode (ISD::SELECT, SDLoc (N), N->getValueType (0 ), Cond,
4921+ SelectLHS, SelectRHS);
4922+ }
4923+
48464924SDValue AMDGPUTargetLowering::performSelectCombine (SDNode *N,
48474925 DAGCombinerInfo &DCI) const {
48484926 if (SDValue Folded = foldFreeOpFromSelect (DCI, SDValue (N, 0 )))
48494927 return Folded;
48504928
4929+ // Try to fold CMP + SELECT patterns with shared constants (both FP and
4930+ // integer).
4931+ if (SDValue Folded = foldCmpSelectWithSharedConstant (N, DCI, Subtarget))
4932+ return Folded;
4933+
48514934 SDValue Cond = N->getOperand (0 );
48524935 if (Cond.getOpcode () != ISD::SETCC)
48534936 return SDValue ();
0 commit comments