Skip to content

Commit 5185c3b

Browse files
committed
[AMDGPU] Try to reuse register with the constant from compare in v_cndmask.
For some targets, the optimization X == Const ? X : Y -> X == Const ? Const : Y can cause extra register usage or redundant immediate encoding for the constant in cndmask generated from the ternary operation. This patch detects such cases and reuses the register from the compare instruction that already holds the constant, instead of materializing it again for cndmask. The optimization avoids immediates that can be encoded into cndmask instruction (including +-0.0), as well as !isNormal() constants.
1 parent 5fd319f commit 5185c3b

File tree

3 files changed

+2466
-0
lines changed

3 files changed

+2466
-0
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4842,11 +4842,93 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
48424842
return SDValue();
48434843
}
48444844

4845+
// Detect when CMP and SELECT use the same constant and fold them to avoid
4846+
// loading the constant twice. Specifically handles patterns like:
4847+
// %cmp = icmp eq i32 %val, 4242
4848+
// %sel = select i1 %cmp, i32 4242, i32 %other
4849+
// It can be optimized to reuse %val instead of 4242 in select.
4850+
static SDValue
4851+
foldCmpSelectWithSharedConstant(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
4852+
const AMDGPUSubtarget *ST) {
4853+
SDValue Cond = N->getOperand(0);
4854+
SDValue TrueVal = N->getOperand(1);
4855+
SDValue FalseVal = N->getOperand(2);
4856+
4857+
// Check if condition is a comparison.
4858+
if (Cond.getOpcode() != ISD::SETCC)
4859+
return SDValue();
4860+
4861+
SDValue LHS = Cond.getOperand(0);
4862+
SDValue RHS = Cond.getOperand(1);
4863+
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
4864+
4865+
bool isFloatingPoint = LHS.getValueType().isFloatingPoint();
4866+
bool isInteger = LHS.getValueType().isInteger();
4867+
4868+
// Handle simple floating-point and integer types only.
4869+
if (!isFloatingPoint && !isInteger)
4870+
return SDValue();
4871+
4872+
bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
4873+
bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
4874+
if (!isEquality && !isNonEquality)
4875+
return SDValue();
4876+
4877+
SDValue ArgVal, ConstVal;
4878+
if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
4879+
(isInteger && isa<ConstantSDNode>(RHS))) {
4880+
ConstVal = RHS;
4881+
ArgVal = LHS;
4882+
} else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
4883+
(isInteger && isa<ConstantSDNode>(LHS))) {
4884+
ConstVal = LHS;
4885+
ArgVal = RHS;
4886+
} else {
4887+
return SDValue();
4888+
}
4889+
4890+
// Check if constant should not be optimized - early return if not.
4891+
if (isFloatingPoint) {
4892+
const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
4893+
const GCNSubtarget *GCNST = static_cast<const GCNSubtarget *>(ST);
4894+
4895+
// Only optimize normal floating-point values, skip optimization for
4896+
// inlinable floating-point constants.
4897+
if (!Val.isNormal() || GCNST->getInstrInfo()->isInlineConstant(Val))
4898+
return SDValue();
4899+
} else {
4900+
int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue();
4901+
4902+
// Skip optimization for inlinable integer immediates.
4903+
// Inlinable immediates include: -16 to 64 (inclusive).
4904+
if (IntVal >= -16 && IntVal <= 64)
4905+
return SDValue();
4906+
}
4907+
4908+
// For equality and non-equality comparisons, patterns:
4909+
// select (setcc x, const), const, y -> select (setcc x, const), x, y
4910+
// select (setccinv x, const), y, const -> select (setccinv x, const), y, x
4911+
if (!(isEquality && TrueVal == ConstVal) &&
4912+
!(isNonEquality && FalseVal == ConstVal))
4913+
return SDValue();
4914+
4915+
SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
4916+
SDValue SelectRHS =
4917+
(isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
4918+
return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
4919+
SelectLHS, SelectRHS);
4920+
}
4921+
48454922
SDValue AMDGPUTargetLowering::performSelectCombine(SDNode *N,
48464923
DAGCombinerInfo &DCI) const {
48474924
if (SDValue Folded = foldFreeOpFromSelect(DCI, SDValue(N, 0)))
48484925
return Folded;
48494926

4927+
// Try to fold CMP + SELECT patterns with shared constants (both FP and
4928+
// integer).
4929+
if (SDValue Folded = foldCmpSelectWithSharedConstant(N, DCI, Subtarget))
4930+
return Folded;
4931+
48504932
SDValue Cond = N->getOperand(0);
48514933
if (Cond.getOpcode() != ISD::SETCC)
48524934
return SDValue();

0 commit comments

Comments
 (0)