Skip to content

Commit b7f6abd

Browse files
dfukalovCopilot
andauthored
[AMDGPU] Try to reuse register with the constant from compare in v_cndmask (#148740)
For some targets, the optimization X == Const ? X : Y -> X == Const ? Const : Y can cause extra register usage or redundant immediate encoding for the constant in cndmask generated from the ternary operation. This patch detects such cases and reuses the register from the compare instruction that already holds the constant, instead of materializing it again for cndmask. The optimization avoids immediates that can be encoded into cndmask instruction (including +-0.0), as well as !isNormal() constants. The change is reworked on the base of #131146 --------- Co-authored-by: Copilot <[email protected]>
1 parent 1e4e2b3 commit b7f6abd

File tree

3 files changed

+2467
-0
lines changed

3 files changed

+2467
-0
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4843,11 +4843,94 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
48434843
return SDValue();
48444844
}
48454845

4846+
// Detect when CMP and SELECT use the same constant and fold them to avoid
4847+
// loading the constant twice. Specifically handles patterns like:
4848+
// %cmp = icmp eq i32 %val, 4242
4849+
// %sel = select i1 %cmp, i32 4242, i32 %other
4850+
// It can be optimized to reuse %val instead of 4242 in select.
4851+
static SDValue
4852+
foldCmpSelectWithSharedConstant(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
4853+
const AMDGPUSubtarget *ST) {
4854+
SDValue Cond = N->getOperand(0);
4855+
SDValue TrueVal = N->getOperand(1);
4856+
SDValue FalseVal = N->getOperand(2);
4857+
4858+
// Check if condition is a comparison.
4859+
if (Cond.getOpcode() != ISD::SETCC)
4860+
return SDValue();
4861+
4862+
SDValue LHS = Cond.getOperand(0);
4863+
SDValue RHS = Cond.getOperand(1);
4864+
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
4865+
4866+
bool isFloatingPoint = LHS.getValueType().isFloatingPoint();
4867+
bool isInteger = LHS.getValueType().isInteger();
4868+
4869+
// Handle simple floating-point and integer types only.
4870+
if (!isFloatingPoint && !isInteger)
4871+
return SDValue();
4872+
4873+
bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
4874+
bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
4875+
if (!isEquality && !isNonEquality)
4876+
return SDValue();
4877+
4878+
SDValue ArgVal, ConstVal;
4879+
if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
4880+
(isInteger && isa<ConstantSDNode>(RHS))) {
4881+
ConstVal = RHS;
4882+
ArgVal = LHS;
4883+
} else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
4884+
(isInteger && isa<ConstantSDNode>(LHS))) {
4885+
ConstVal = LHS;
4886+
ArgVal = RHS;
4887+
} else {
4888+
return SDValue();
4889+
}
4890+
4891+
// Check if constant should not be optimized - early return if not.
4892+
if (isFloatingPoint) {
4893+
const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
4894+
const GCNSubtarget *GCNST = static_cast<const GCNSubtarget *>(ST);
4895+
4896+
// Only optimize normal floating-point values (finite, non-zero, and
4897+
// non-subnormal as per IEEE 754), skip optimization for inlinable
4898+
// floating-point constants.
4899+
if (!Val.isNormal() || GCNST->getInstrInfo()->isInlineConstant(Val))
4900+
return SDValue();
4901+
} else {
4902+
int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue();
4903+
4904+
// Skip optimization for inlinable integer immediates.
4905+
// Inlinable immediates include: -16 to 64 (inclusive).
4906+
if (IntVal >= -16 && IntVal <= 64)
4907+
return SDValue();
4908+
}
4909+
4910+
// For equality and non-equality comparisons, patterns:
4911+
// select (setcc x, const), const, y -> select (setcc x, const), x, y
4912+
// select (setccinv x, const), y, const -> select (setccinv x, const), y, x
4913+
if (!(isEquality && TrueVal == ConstVal) &&
4914+
!(isNonEquality && FalseVal == ConstVal))
4915+
return SDValue();
4916+
4917+
SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
4918+
SDValue SelectRHS =
4919+
(isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
4920+
return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
4921+
SelectLHS, SelectRHS);
4922+
}
4923+
48464924
SDValue AMDGPUTargetLowering::performSelectCombine(SDNode *N,
48474925
DAGCombinerInfo &DCI) const {
48484926
if (SDValue Folded = foldFreeOpFromSelect(DCI, SDValue(N, 0)))
48494927
return Folded;
48504928

4929+
// Try to fold CMP + SELECT patterns with shared constants (both FP and
4930+
// integer).
4931+
if (SDValue Folded = foldCmpSelectWithSharedConstant(N, DCI, Subtarget))
4932+
return Folded;
4933+
48514934
SDValue Cond = N->getOperand(0);
48524935
if (Cond.getOpcode() != ISD::SETCC)
48534936
return SDValue();

0 commit comments

Comments
 (0)