Skip to content

Commit 46786e7

Browse files
committed
Fix scalarisation of the AND when folding a 64-bit ashr into a 32-bit shift
Factor the shift-reducing combine logic into a single function, as it was duplicated across all three shift combine functions.
1 parent 573adfe commit 46786e7

File tree

3 files changed

+155
-250
lines changed

3 files changed

+155
-250
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 39 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -725,18 +725,6 @@ static bool selectSupportsSourceMods(const SDNode *N) {
725725
return N->getValueType(0) == MVT::f32;
726726
}
727727

728-
LLVM_READONLY
729-
static bool buildVectorSupportsSourceMods(const SDNode *N) {
730-
if (N->getValueType(0) != MVT::v2f32)
731-
return true;
732-
733-
if (N->getOperand(0)->getOpcode() != ISD::SELECT ||
734-
N->getOperand(1)->getOpcode() != ISD::SELECT)
735-
return true;
736-
737-
return false;
738-
}
739-
740728
// Most FP instructions support source modifiers, but this could be refined
741729
// slightly.
742730
LLVM_READONLY
@@ -770,8 +758,6 @@ static bool hasSourceMods(const SDNode *N) {
770758
return true;
771759
}
772760
}
773-
case ISD::BUILD_VECTOR:
774-
return buildVectorSupportsSourceMods(N);
775761
case ISD::SELECT:
776762
return selectSupportsSourceMods(N);
777763
default:
@@ -4068,15 +4054,21 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
40684054
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, Vec);
40694055
}
40704056

4071-
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4072-
DAGCombinerInfo &DCI) const {
4073-
EVT VT = N->getValueType(0);
4074-
SDValue LHS = N->getOperand(0);
4075-
SDValue RHS = N->getOperand(1);
4076-
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4077-
SDLoc SL(N);
4078-
SelectionDAG &DAG = DCI.DAG;
4057+
// Part of the shift combines is to optimise for the case where it's possible
4058+
// to reduce e.g. shl64 to shl32 if the shift range is [32, 63]. This
4059+
// transforms: DST = shl i64 X, Y to [0, shl i32 X, (Y & 31) ]. The
4060+
// '&' is then elided by ISel. The vector code for this was being
4061+
// completely scalarised by the vector legalizer, but when v2i32 is
4062+
// legal the vector legaliser only partially scalarises the
4063+
// vector operations and the AND is not elided. This function
4064+
// scalarises the AND for this optimisation case.
4065+
static SDValue getShiftForReduction(unsigned ShiftOpc, SDValue LHS, SDValue RHS,
4066+
SelectionDAG &DAG) {
4067+
assert(
4068+
(ShiftOpc == ISD::SRA || ShiftOpc == ISD::SRL || ShiftOpc == ISD::SHL) &&
4069+
"Expected shift Opcode.");
40794070

4071+
SDLoc SL = SDLoc(RHS);
40804072
if (RHS->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
40814073
SDValue VAND = RHS.getOperand(0);
40824074
if (ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1))) {
@@ -4085,15 +4077,6 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40854077
SDValue LHSAND = VAND.getOperand(0);
40864078
SDValue RHSAND = VAND.getOperand(1);
40874079
if (RHSAND->getOpcode() == ISD::BUILD_VECTOR) {
4088-
// Part of shlcombine is to optimise for the case where its possible
4089-
// to reduce shl64 to shl32 if shift range is [63-32]. This
4090-
// transforms: DST = shl i64 X, Y to [0, shl i32 X, (Y & 31) ]. The
4091-
// '&' is then elided by ISel. The vector code for this was being
4092-
// completely scalarised by the vector legalizer, but now v2i32 is
4093-
// made legal the vector legaliser only partially scalarises the
4094-
// vector operations and the and was not elided. This check enables us
4095-
// to locate and scalarise the v2i32 and and re-enable ISel to elide
4096-
// the and instruction.
40974080
ConstantSDNode *CANDL =
40984081
dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
40994082
ConstantSDNode *CANDR =
@@ -4107,19 +4090,33 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
41074090
LHSAND, Zero);
41084091
SDValue Hi =
41094092
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4110-
SDValue LoAnd =
4111-
DAG.getNode(ISD::AND, SL, MVT::i32, Lo, RHSAND->getOperand(0));
4112-
SDValue HiAnd =
4113-
DAG.getNode(ISD::AND, SL, MVT::i32, Hi, RHSAND->getOperand(0));
4093+
SDValue AndMask = DAG.getConstant(0x1f, SL, MVT::i32);
4094+
SDValue LoAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Lo, AndMask);
4095+
SDValue HiAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Hi, AndMask);
41144096
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
41154097
if (AndIndex == 0 || AndIndex == 1)
4116-
return DAG.getNode(ISD::SHL, SL, MVT::i32, Trunc,
4117-
AndIndex == 0 ? LoAnd : HiAnd, N->getFlags());
4098+
return DAG.getNode(ShiftOpc, SL, MVT::i32, Trunc,
4099+
AndIndex == 0 ? LoAnd : HiAnd,
4100+
RHS->getFlags());
41184101
}
41194102
}
41204103
}
41214104
}
41224105
}
4106+
return SDValue();
4107+
}
4108+
4109+
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4110+
DAGCombinerInfo &DCI) const {
4111+
EVT VT = N->getValueType(0);
4112+
SDValue LHS = N->getOperand(0);
4113+
SDValue RHS = N->getOperand(1);
4114+
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4115+
SDLoc SL(N);
4116+
SelectionDAG &DAG = DCI.DAG;
4117+
4118+
if (SDValue SS = getShiftForReduction(ISD::SHL, LHS, RHS, DAG))
4119+
return SS;
41234120

41244121
unsigned RHSVal;
41254122
if (CRHS) {
@@ -4221,6 +4218,9 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
42214218
SelectionDAG &DAG = DCI.DAG;
42224219
SDLoc SL(N);
42234220

4221+
if (SDValue SS = getShiftForReduction(ISD::SRA, LHS, RHS, DAG))
4222+
return SS;
4223+
42244224
if (VT.getScalarType() != MVT::i64)
42254225
return SDValue();
42264226

@@ -4313,52 +4313,6 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
43134313
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
43144314
}
43154315

4316-
static SDValue getScalarisedShift(SDValue LHS, SDValue RHS, SelectionDAG &DAG) {
4317-
SDLoc SL = SDLoc(RHS);
4318-
if (RHS->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
4319-
SDValue VAND = RHS.getOperand(0);
4320-
if (ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1))) {
4321-
uint64_t AndIndex = RHS->getConstantOperandVal(1);
4322-
if (VAND->getOpcode() == ISD::AND && CRRHS) {
4323-
SDValue LHSAND = VAND.getOperand(0);
4324-
SDValue RHSAND = VAND.getOperand(1);
4325-
if (RHSAND->getOpcode() == ISD::BUILD_VECTOR) {
4326-
// Part of srlcombine is to optimise for the case where its possible
4327-
// to reduce shl64 to shl32 if shift range is [63-32]. This
4328-
// transforms: DST = shl i64 X, Y to [0, srl i32 X, (Y & 31) ]. The
4329-
// '&' is then elided by ISel. The vector code for this was being
4330-
// completely scalarised by the vector legalizer, but now v2i32 is
4331-
// made legal the vector legaliser only partially scalarises the
4332-
// vector operations and the and was not elided. This check enables us
4333-
// to locate and scalarise the v2i32 and and re-enable ISel to elide
4334-
// the and instruction.
4335-
ConstantSDNode *CANDL =
4336-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
4337-
ConstantSDNode *CANDR =
4338-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(1));
4339-
if (CANDL && CANDR && RHSAND->getConstantOperandVal(0) == 0x1f &&
4340-
RHSAND->getConstantOperandVal(1) == 0x1f) {
4341-
// Get the non-const AND operands and produce scalar AND
4342-
const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
4343-
const SDValue One = DAG.getConstant(1, SL, MVT::i32);
4344-
SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32,
4345-
LHSAND, Zero);
4346-
SDValue Hi =
4347-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4348-
SDValue AndMask = DAG.getConstant(0x1f, SL, MVT::i32);
4349-
SDValue LoAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Lo, AndMask);
4350-
SDValue HiAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Hi, AndMask);
4351-
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
4352-
if (AndIndex == 0 || AndIndex == 1)
4353-
return DAG.getNode(ISD::SRL, SL, MVT::i32, Trunc,
4354-
AndIndex == 0 ? LoAnd : HiAnd, RHS->getFlags());
4355-
}
4356-
}
4357-
}
4358-
}
4359-
}
4360-
return SDValue();
4361-
}
43624316

43634317
SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
43644318
DAGCombinerInfo &DCI) const {
@@ -4370,7 +4324,8 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
43704324
SDLoc SL(N);
43714325
unsigned RHSVal;
43724326

4373-
4327+
if (SDValue SS = getShiftForReduction(ISD::SRL, LHS, RHS, DAG))
4328+
return SS;
43744329

43754330
if (CRHS) {
43764331
RHSVal = CRHS->getZExtValue();

0 commit comments

Comments
 (0)