@@ -4843,11 +4843,94 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
4843
4843
return SDValue ();
4844
4844
}
4845
4845
4846
+ // Detect when CMP and SELECT use the same constant and fold them to avoid
4847
+ // loading the constant twice. Specifically handles patterns like:
4848
+ // %cmp = icmp eq i32 %val, 4242
4849
+ // %sel = select i1 %cmp, i32 4242, i32 %other
4850
+ // It can be optimized to reuse %val instead of 4242 in select.
4851
+ static SDValue
4852
+ foldCmpSelectWithSharedConstant (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
4853
+ const AMDGPUSubtarget *ST) {
4854
+ SDValue Cond = N->getOperand (0 );
4855
+ SDValue TrueVal = N->getOperand (1 );
4856
+ SDValue FalseVal = N->getOperand (2 );
4857
+
4858
+ // Check if condition is a comparison.
4859
+ if (Cond.getOpcode () != ISD::SETCC)
4860
+ return SDValue ();
4861
+
4862
+ SDValue LHS = Cond.getOperand (0 );
4863
+ SDValue RHS = Cond.getOperand (1 );
4864
+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand (2 ))->get ();
4865
+
4866
+ bool isFloatingPoint = LHS.getValueType ().isFloatingPoint ();
4867
+ bool isInteger = LHS.getValueType ().isInteger ();
4868
+
4869
+ // Handle simple floating-point and integer types only.
4870
+ if (!isFloatingPoint && !isInteger)
4871
+ return SDValue ();
4872
+
4873
+ bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
4874
+ bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
4875
+ if (!isEquality && !isNonEquality)
4876
+ return SDValue ();
4877
+
4878
+ SDValue ArgVal, ConstVal;
4879
+ if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
4880
+ (isInteger && isa<ConstantSDNode>(RHS))) {
4881
+ ConstVal = RHS;
4882
+ ArgVal = LHS;
4883
+ } else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
4884
+ (isInteger && isa<ConstantSDNode>(LHS))) {
4885
+ ConstVal = LHS;
4886
+ ArgVal = RHS;
4887
+ } else {
4888
+ return SDValue ();
4889
+ }
4890
+
4891
+ // Check if constant should not be optimized - early return if not.
4892
+ if (isFloatingPoint) {
4893
+ const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF ();
4894
+ const GCNSubtarget *GCNST = static_cast <const GCNSubtarget *>(ST);
4895
+
4896
+ // Only optimize normal floating-point values (finite, non-zero, and
4897
+ // non-subnormal as per IEEE 754), skip optimization for inlinable
4898
+ // floating-point constants.
4899
+ if (!Val.isNormal () || GCNST->getInstrInfo ()->isInlineConstant (Val))
4900
+ return SDValue ();
4901
+ } else {
4902
+ int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue ();
4903
+
4904
+ // Skip optimization for inlinable integer immediates.
4905
+ // Inlinable immediates include: -16 to 64 (inclusive).
4906
+ if (IntVal >= -16 && IntVal <= 64 )
4907
+ return SDValue ();
4908
+ }
4909
+
4910
+ // For equality and non-equality comparisons, patterns:
4911
+ // select (setcc x, const), const, y -> select (setcc x, const), x, y
4912
+ // select (setccinv x, const), y, const -> select (setccinv x, const), y, x
4913
+ if (!(isEquality && TrueVal == ConstVal) &&
4914
+ !(isNonEquality && FalseVal == ConstVal))
4915
+ return SDValue ();
4916
+
4917
+ SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
4918
+ SDValue SelectRHS =
4919
+ (isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
4920
+ return DCI.DAG .getNode (ISD::SELECT, SDLoc (N), N->getValueType (0 ), Cond,
4921
+ SelectLHS, SelectRHS);
4922
+ }
4923
+
4846
4924
SDValue AMDGPUTargetLowering::performSelectCombine (SDNode *N,
4847
4925
DAGCombinerInfo &DCI) const {
4848
4926
if (SDValue Folded = foldFreeOpFromSelect (DCI, SDValue (N, 0 )))
4849
4927
return Folded;
4850
4928
4929
+ // Try to fold CMP + SELECT patterns with shared constants (both FP and
4930
+ // integer).
4931
+ if (SDValue Folded = foldCmpSelectWithSharedConstant (N, DCI, Subtarget))
4932
+ return Folded;
4933
+
4851
4934
SDValue Cond = N->getOperand (0 );
4852
4935
if (Cond.getOpcode () != ISD::SETCC)
4853
4936
return SDValue ();
0 commit comments