Skip to content

Commit e650c4b

Browse files
authored
[NFC][AMDGPU] Move cmp+select arguments optimization to SIISelLowering. (#150929)
As requested in #148740.
1 parent ea480cc commit e650c4b

File tree

3 files changed

+77
-83
lines changed

3 files changed

+77
-83
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4846,94 +4846,11 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
48464846
return SDValue();
48474847
}
48484848

4849-
// Detect when CMP and SELECT use the same constant and fold them to avoid
4850-
// loading the constant twice. Specifically handles patterns like:
4851-
// %cmp = icmp eq i32 %val, 4242
4852-
// %sel = select i1 %cmp, i32 4242, i32 %other
4853-
// It can be optimized to reuse %val instead of 4242 in select.
4854-
static SDValue
4855-
foldCmpSelectWithSharedConstant(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
4856-
const AMDGPUSubtarget *ST) {
4857-
SDValue Cond = N->getOperand(0);
4858-
SDValue TrueVal = N->getOperand(1);
4859-
SDValue FalseVal = N->getOperand(2);
4860-
4861-
// Check if condition is a comparison.
4862-
if (Cond.getOpcode() != ISD::SETCC)
4863-
return SDValue();
4864-
4865-
SDValue LHS = Cond.getOperand(0);
4866-
SDValue RHS = Cond.getOperand(1);
4867-
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
4868-
4869-
bool isFloatingPoint = LHS.getValueType().isFloatingPoint();
4870-
bool isInteger = LHS.getValueType().isInteger();
4871-
4872-
// Handle simple floating-point and integer types only.
4873-
if (!isFloatingPoint && !isInteger)
4874-
return SDValue();
4875-
4876-
bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
4877-
bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
4878-
if (!isEquality && !isNonEquality)
4879-
return SDValue();
4880-
4881-
SDValue ArgVal, ConstVal;
4882-
if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
4883-
(isInteger && isa<ConstantSDNode>(RHS))) {
4884-
ConstVal = RHS;
4885-
ArgVal = LHS;
4886-
} else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
4887-
(isInteger && isa<ConstantSDNode>(LHS))) {
4888-
ConstVal = LHS;
4889-
ArgVal = RHS;
4890-
} else {
4891-
return SDValue();
4892-
}
4893-
4894-
// Check if constant should not be optimized - early return if not.
4895-
if (isFloatingPoint) {
4896-
const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
4897-
const GCNSubtarget *GCNST = static_cast<const GCNSubtarget *>(ST);
4898-
4899-
// Only optimize normal floating-point values (finite, non-zero, and
4900-
// non-subnormal as per IEEE 754), skip optimization for inlinable
4901-
// floating-point constants.
4902-
if (!Val.isNormal() || GCNST->getInstrInfo()->isInlineConstant(Val))
4903-
return SDValue();
4904-
} else {
4905-
int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue();
4906-
4907-
// Skip optimization for inlinable integer immediates.
4908-
// Inlinable immediates include: -16 to 64 (inclusive).
4909-
if (IntVal >= -16 && IntVal <= 64)
4910-
return SDValue();
4911-
}
4912-
4913-
// For equality and non-equality comparisons, patterns:
4914-
// select (setcc x, const), const, y -> select (setcc x, const), x, y
4915-
// select (setccinv x, const), y, const -> select (setccinv x, const), y, x
4916-
if (!(isEquality && TrueVal == ConstVal) &&
4917-
!(isNonEquality && FalseVal == ConstVal))
4918-
return SDValue();
4919-
4920-
SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
4921-
SDValue SelectRHS =
4922-
(isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
4923-
return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
4924-
SelectLHS, SelectRHS);
4925-
}
4926-
49274849
SDValue AMDGPUTargetLowering::performSelectCombine(SDNode *N,
49284850
DAGCombinerInfo &DCI) const {
49294851
if (SDValue Folded = foldFreeOpFromSelect(DCI, SDValue(N, 0)))
49304852
return Folded;
49314853

4932-
// Try to fold CMP + SELECT patterns with shared constants (both FP and
4933-
// integer).
4934-
if (SDValue Folded = foldCmpSelectWithSharedConstant(N, DCI, Subtarget))
4935-
return Folded;
4936-
49374854
SDValue Cond = N->getOperand(0);
49384855
if (Cond.getOpcode() != ISD::SETCC)
49394856
return SDValue();

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15896,6 +15896,78 @@ SDValue SITargetLowering::performClampCombine(SDNode *N,
1589615896
return SDValue(CSrc, 0);
1589715897
}
1589815898

15899+
SDValue SITargetLowering::performSelectCombine(SDNode *N,
15900+
DAGCombinerInfo &DCI) const {
15901+
15902+
// Try to fold CMP + SELECT patterns with shared constants (both FP and
15903+
// integer).
15904+
// Detect when CMP and SELECT use the same constant and fold them to avoid
15905+
// loading the constant twice. Specifically handles patterns like:
15906+
// %cmp = icmp eq i32 %val, 4242
15907+
// %sel = select i1 %cmp, i32 4242, i32 %other
15908+
// It can be optimized to reuse %val instead of 4242 in select.
15909+
SDValue Cond = N->getOperand(0);
15910+
SDValue TrueVal = N->getOperand(1);
15911+
SDValue FalseVal = N->getOperand(2);
15912+
15913+
// Check if condition is a comparison.
15914+
if (Cond.getOpcode() != ISD::SETCC)
15915+
return SDValue();
15916+
15917+
SDValue LHS = Cond.getOperand(0);
15918+
SDValue RHS = Cond.getOperand(1);
15919+
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
15920+
15921+
bool isFloatingPoint = LHS.getValueType().isFloatingPoint();
15922+
bool isInteger = LHS.getValueType().isInteger();
15923+
15924+
// Handle simple floating-point and integer types only.
15925+
if (!isFloatingPoint && !isInteger)
15926+
return SDValue();
15927+
15928+
bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
15929+
bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
15930+
if (!isEquality && !isNonEquality)
15931+
return SDValue();
15932+
15933+
SDValue ArgVal, ConstVal;
15934+
if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
15935+
(isInteger && isa<ConstantSDNode>(RHS))) {
15936+
ConstVal = RHS;
15937+
ArgVal = LHS;
15938+
} else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
15939+
(isInteger && isa<ConstantSDNode>(LHS))) {
15940+
ConstVal = LHS;
15941+
ArgVal = RHS;
15942+
} else {
15943+
return SDValue();
15944+
}
15945+
15946+
// Skip optimization for inlinable immediates.
15947+
if (isFloatingPoint) {
15948+
const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
15949+
if (!Val.isNormal() || Subtarget->getInstrInfo()->isInlineConstant(Val))
15950+
return SDValue();
15951+
} else {
15952+
if (AMDGPU::isInlinableIntLiteral(
15953+
cast<ConstantSDNode>(ConstVal)->getSExtValue()))
15954+
return SDValue();
15955+
}
15956+
15957+
// For equality and non-equality comparisons, patterns:
15958+
// select (setcc x, const), const, y -> select (setcc x, const), x, y
15959+
// select (setccinv x, const), y, const -> select (setccinv x, const), y, x
15960+
if (!(isEquality && TrueVal == ConstVal) &&
15961+
!(isNonEquality && FalseVal == ConstVal))
15962+
return SDValue();
15963+
15964+
SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
15965+
SDValue SelectRHS =
15966+
(isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
15967+
return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
15968+
SelectLHS, SelectRHS);
15969+
}
15970+
1589915971
SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1590015972
DAGCombinerInfo &DCI) const {
1590115973
switch (N->getOpcode()) {
@@ -15944,6 +16016,10 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1594416016
return performFMulCombine(N, DCI);
1594516017
case ISD::SETCC:
1594616018
return performSetCCCombine(N, DCI);
16019+
case ISD::SELECT:
16020+
if (auto Res = performSelectCombine(N, DCI))
16021+
return Res;
16022+
break;
1594716023
case ISD::FMAXNUM:
1594816024
case ISD::FMINNUM:
1594916025
case ISD::FMAXNUM_IEEE:

llvm/lib/Target/AMDGPU/SIISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
211211
SDValue performExtractVectorEltCombine(SDNode *N, DAGCombinerInfo &DCI) const;
212212
SDValue performInsertVectorEltCombine(SDNode *N, DAGCombinerInfo &DCI) const;
213213
SDValue performFPRoundCombine(SDNode *N, DAGCombinerInfo &DCI) const;
214+
SDValue performSelectCombine(SDNode *N, DAGCombinerInfo &DCI) const;
214215

215216
SDValue reassociateScalarOps(SDNode *N, SelectionDAG &DAG) const;
216217
unsigned getFusedOpcode(const SelectionDAG &DAG,

0 commit comments

Comments
 (0)