Skip to content

Commit 332c8b9

Browse files
committed
Fix scalarisation of the AND when folding a 64-bit ashr into a 32-bit shift
Factor the shift-reducing combine logic into a single function, as the same logic was applied in all three shift combine functions.
1 parent a28f920 commit 332c8b9

File tree

3 files changed

+155
-250
lines changed

3 files changed

+155
-250
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 39 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -726,18 +726,6 @@ static bool selectSupportsSourceMods(const SDNode *N) {
726726
return N->getValueType(0) == MVT::f32;
727727
}
728728

729-
LLVM_READONLY
730-
static bool buildVectorSupportsSourceMods(const SDNode *N) {
731-
if (N->getValueType(0) != MVT::v2f32)
732-
return true;
733-
734-
if (N->getOperand(0)->getOpcode() != ISD::SELECT ||
735-
N->getOperand(1)->getOpcode() != ISD::SELECT)
736-
return true;
737-
738-
return false;
739-
}
740-
741729
// Most FP instructions support source modifiers, but this could be refined
742730
// slightly.
743731
LLVM_READONLY
@@ -771,8 +759,6 @@ static bool hasSourceMods(const SDNode *N) {
771759
return true;
772760
}
773761
}
774-
case ISD::BUILD_VECTOR:
775-
return buildVectorSupportsSourceMods(N);
776762
case ISD::SELECT:
777763
return selectSupportsSourceMods(N);
778764
default:
@@ -4069,15 +4055,21 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
40694055
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, Vec);
40704056
}
40714057

4072-
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4073-
DAGCombinerInfo &DCI) const {
4074-
EVT VT = N->getValueType(0);
4075-
SDValue LHS = N->getOperand(0);
4076-
SDValue RHS = N->getOperand(1);
4077-
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4078-
SDLoc SL(N);
4079-
SelectionDAG &DAG = DCI.DAG;
4058+
// Part of the shift combines is to optimise for the case where it's possible
4059+
// to reduce e.g. shl64 to shl32 if the shift range is [32, 63]. This
4060+
// transforms: DST = shl i64 X, Y to [0, shl i32 X, (Y & 31) ]. The
4061+
// '&' is then elided by ISel. The vector code for this was being
4062+
// completely scalarised by the vector legalizer, but when v2i32 is
4063+
// legal the vector legaliser only partially scalarises the
4064+
// vector operations and the AND is not elided. This function
4065+
// scalarises the AND for this optimisation case.
4066+
static SDValue getShiftForReduction(unsigned ShiftOpc, SDValue LHS, SDValue RHS,
4067+
SelectionDAG &DAG) {
4068+
assert(
4069+
(ShiftOpc == ISD::SRA || ShiftOpc == ISD::SRL || ShiftOpc == ISD::SHL) &&
4070+
"Expected shift Opcode.");
40804071

4072+
SDLoc SL = SDLoc(RHS);
40814073
if (RHS->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
40824074
SDValue VAND = RHS.getOperand(0);
40834075
if (ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1))) {
@@ -4086,15 +4078,6 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40864078
SDValue LHSAND = VAND.getOperand(0);
40874079
SDValue RHSAND = VAND.getOperand(1);
40884080
if (RHSAND->getOpcode() == ISD::BUILD_VECTOR) {
4089-
// Part of shlcombine is to optimise for the case where its possible
4090-
// to reduce shl64 to shl32 if shift range is [63-32]. This
4091-
// transforms: DST = shl i64 X, Y to [0, shl i32 X, (Y & 31) ]. The
4092-
// '&' is then elided by ISel. The vector code for this was being
4093-
// completely scalarised by the vector legalizer, but now v2i32 is
4094-
// made legal the vector legaliser only partially scalarises the
4095-
// vector operations and the and was not elided. This check enables us
4096-
// to locate and scalarise the v2i32 and and re-enable ISel to elide
4097-
// the and instruction.
40984081
ConstantSDNode *CANDL =
40994082
dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
41004083
ConstantSDNode *CANDR =
@@ -4108,19 +4091,33 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
41084091
LHSAND, Zero);
41094092
SDValue Hi =
41104093
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4111-
SDValue LoAnd =
4112-
DAG.getNode(ISD::AND, SL, MVT::i32, Lo, RHSAND->getOperand(0));
4113-
SDValue HiAnd =
4114-
DAG.getNode(ISD::AND, SL, MVT::i32, Hi, RHSAND->getOperand(0));
4094+
SDValue AndMask = DAG.getConstant(0x1f, SL, MVT::i32);
4095+
SDValue LoAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Lo, AndMask);
4096+
SDValue HiAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Hi, AndMask);
41154097
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
41164098
if (AndIndex == 0 || AndIndex == 1)
4117-
return DAG.getNode(ISD::SHL, SL, MVT::i32, Trunc,
4118-
AndIndex == 0 ? LoAnd : HiAnd, N->getFlags());
4099+
return DAG.getNode(ShiftOpc, SL, MVT::i32, Trunc,
4100+
AndIndex == 0 ? LoAnd : HiAnd,
4101+
RHS->getFlags());
41194102
}
41204103
}
41214104
}
41224105
}
41234106
}
4107+
return SDValue();
4108+
}
4109+
4110+
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4111+
DAGCombinerInfo &DCI) const {
4112+
EVT VT = N->getValueType(0);
4113+
SDValue LHS = N->getOperand(0);
4114+
SDValue RHS = N->getOperand(1);
4115+
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4116+
SDLoc SL(N);
4117+
SelectionDAG &DAG = DCI.DAG;
4118+
4119+
if (SDValue SS = getShiftForReduction(ISD::SHL, LHS, RHS, DAG))
4120+
return SS;
41244121

41254122
unsigned RHSVal;
41264123
if (CRHS) {
@@ -4222,6 +4219,9 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
42224219
SelectionDAG &DAG = DCI.DAG;
42234220
SDLoc SL(N);
42244221

4222+
if (SDValue SS = getShiftForReduction(ISD::SRA, LHS, RHS, DAG))
4223+
return SS;
4224+
42254225
if (VT.getScalarType() != MVT::i64)
42264226
return SDValue();
42274227

@@ -4314,52 +4314,6 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
43144314
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
43154315
}
43164316

4317-
static SDValue getScalarisedShift(SDValue LHS, SDValue RHS, SelectionDAG &DAG) {
4318-
SDLoc SL = SDLoc(RHS);
4319-
if (RHS->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
4320-
SDValue VAND = RHS.getOperand(0);
4321-
if (ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1))) {
4322-
uint64_t AndIndex = RHS->getConstantOperandVal(1);
4323-
if (VAND->getOpcode() == ISD::AND && CRRHS) {
4324-
SDValue LHSAND = VAND.getOperand(0);
4325-
SDValue RHSAND = VAND.getOperand(1);
4326-
if (RHSAND->getOpcode() == ISD::BUILD_VECTOR) {
4327-
// Part of srlcombine is to optimise for the case where its possible
4328-
// to reduce shl64 to shl32 if shift range is [63-32]. This
4329-
// transforms: DST = shl i64 X, Y to [0, srl i32 X, (Y & 31) ]. The
4330-
// '&' is then elided by ISel. The vector code for this was being
4331-
// completely scalarised by the vector legalizer, but now v2i32 is
4332-
// made legal the vector legaliser only partially scalarises the
4333-
// vector operations and the and was not elided. This check enables us
4334-
// to locate and scalarise the v2i32 and and re-enable ISel to elide
4335-
// the and instruction.
4336-
ConstantSDNode *CANDL =
4337-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
4338-
ConstantSDNode *CANDR =
4339-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(1));
4340-
if (CANDL && CANDR && RHSAND->getConstantOperandVal(0) == 0x1f &&
4341-
RHSAND->getConstantOperandVal(1) == 0x1f) {
4342-
// Get the non-const AND operands and produce scalar AND
4343-
const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
4344-
const SDValue One = DAG.getConstant(1, SL, MVT::i32);
4345-
SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32,
4346-
LHSAND, Zero);
4347-
SDValue Hi =
4348-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4349-
SDValue AndMask = DAG.getConstant(0x1f, SL, MVT::i32);
4350-
SDValue LoAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Lo, AndMask);
4351-
SDValue HiAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Hi, AndMask);
4352-
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
4353-
if (AndIndex == 0 || AndIndex == 1)
4354-
return DAG.getNode(ISD::SRL, SL, MVT::i32, Trunc,
4355-
AndIndex == 0 ? LoAnd : HiAnd, RHS->getFlags());
4356-
}
4357-
}
4358-
}
4359-
}
4360-
}
4361-
return SDValue();
4362-
}
43634317

43644318
SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
43654319
DAGCombinerInfo &DCI) const {
@@ -4371,7 +4325,8 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
43714325
SDLoc SL(N);
43724326
unsigned RHSVal;
43734327

4374-
4328+
if (SDValue SS = getShiftForReduction(ISD::SRL, LHS, RHS, DAG))
4329+
return SS;
43754330

43764331
if (CRHS) {
43774332
RHSVal = CRHS->getZExtValue();

0 commit comments

Comments
 (0)