Skip to content

Commit 2561e72

Browse files
committed
Work to fix regressions in integer select srcmod generation when v2i32
or/xor/and are legalized.
1 parent 7b31e62 commit 2561e72

File tree

5 files changed

+158
-152
lines changed

5 files changed

+158
-152
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3225,29 +3225,51 @@ bool AMDGPUDAGToDAGISel::SelectVOP3ModsImpl(SDValue In, SDValue &Src,
32253225
if (IsCanonicalizing)
32263226
return true;
32273227

3228-
unsigned Opc = Src->getOpcode();
3228+
// v2i32 xor/or/and are legal. A vselect using these instructions as operands
3229+
// is scalarised into two selects with EXTRACT_VECTOR_ELT operands. Peek
3230+
// through the extract to the bitwise op.
3231+
SDValue PeekSrc =
3232+
Src->getOpcode() == ISD::EXTRACT_VECTOR_ELT ? Src->getOperand(0) : Src;
3233+
// Convert various sign-bit masks to src mods. Currently disabled for 16-bit
3234+
// types as the codegen replaces the operand without adding a srcmod.
3235+
// This is intentionally finding the cases where we are performing float neg
3236+
// and abs on int types, the goal is not to obtain two's complement neg or
3237+
// abs.
3238+
// TODO: Add 16-bit support.
3239+
unsigned Opc = PeekSrc.getOpcode();
32293240
EVT VT = Src.getValueType();
32303241
if ((Opc != ISD::AND && Opc != ISD::OR && Opc != ISD::XOR) ||
3231-
(VT != MVT::i32 && VT != MVT::i64))
3242+
(VT != MVT::i32 && VT != MVT::v2i32 && VT != MVT::i64))
32323243
return true;
32333244

3234-
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(Src->getOperand(1));
3245+
ConstantSDNode *CRHS =
3246+
isConstOrConstSplat(PeekSrc ? PeekSrc->getOperand(1) : Src->getOperand(1));
32353247
if (!CRHS)
32363248
return true;
32373249

3250+
auto ReplaceSrc = [&]() -> SDValue {
3251+
if (Src->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
3252+
SDValue LHS = PeekSrc->getOperand(0);
3253+
SDValue Index = Src->getOperand(1);
3254+
return CurDAG->getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(Src),
3255+
Src.getValueType(), LHS, Index);
3256+
}
3257+
return PeekSrc.getOperand(0);
3258+
};
3259+
32383260
// Recognise (xor a, 0x80000000) as NEG SrcMod.
32393261
// Recognise (and a, 0x7fffffff) as ABS SrcMod.
32403262
// Recognise (or a, 0x80000000) as NEG+ABS SrcModifiers.
32413263
if (Opc == ISD::XOR && CRHS->getAPIntValue().isSignMask()) {
32423264
Mods |= SISrcMods::NEG;
3243-
Src = Src.getOperand(0);
3265+
Src = ReplaceSrc();
32443266
} else if (Opc == ISD::AND && AllowAbs &&
32453267
CRHS->getAPIntValue().isMaxSignedValue()) {
32463268
Mods |= SISrcMods::ABS;
3247-
Src = Src.getOperand(0);
3269+
Src = ReplaceSrc();
32483270
} else if (Opc == ISD::OR && AllowAbs && CRHS->getAPIntValue().isSignMask()) {
32493271
Mods |= SISrcMods::ABS | SISrcMods::NEG;
3250-
Src = Src.getOperand(0);
3272+
Src = ReplaceSrc();
32513273
}
32523274

32533275
return true;

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 64 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -719,18 +719,6 @@ static bool selectSupportsSourceMods(const SDNode *N) {
719719
return N->getValueType(0) == MVT::f32;
720720
}
721721

722-
LLVM_READONLY
723-
static bool buildVectorSupportsSourceMods(const SDNode *N) {
724-
if (N->getValueType(0) != MVT::v2f32)
725-
return true;
726-
727-
if (N->getOperand(0)->getOpcode() != ISD::SELECT ||
728-
N->getOperand(1)->getOpcode() != ISD::SELECT)
729-
return true;
730-
731-
return false;
732-
}
733-
734722
// Most FP instructions support source modifiers, but this could be refined
735723
// slightly.
736724
LLVM_READONLY
@@ -764,8 +752,6 @@ static bool hasSourceMods(const SDNode *N) {
764752
return true;
765753
}
766754
}
767-
case ISD::BUILD_VECTOR:
768-
return buildVectorSupportsSourceMods(N);
769755
case ISD::SELECT:
770756
return selectSupportsSourceMods(N);
771757
default:
@@ -4062,6 +4048,59 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
40624048
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, Vec);
40634049
}
40644050

4051+
// Part of the shift combines is to optimise for the case where its possible
4052+
// to reduce e.g shl64 to shl32 if shift range is [63-32]. This
4053+
// transforms: DST = shl i64 X, Y to [0, srl i32 X, (Y & 31) ]. The
4054+
// '&' is then elided by ISel. The vector code for this was being
4055+
// completely scalarised by the vector legalizer, but when v2i32 is
4056+
// legal the vector legaliser only partially scalarises the
4057+
// vector operations and the and is not elided. This function
4058+
// scalarises the AND for this optimisation case.
4059+
static SDValue getShiftForReduction(unsigned ShiftOpc, SDValue LHS, SDValue RHS,
4060+
SelectionDAG &DAG) {
4061+
assert(
4062+
(ShiftOpc == ISD::SRA || ShiftOpc == ISD::SRL || ShiftOpc == ISD::SHL) &&
4063+
"Expected shift Opcode.");
4064+
4065+
SDLoc SL = SDLoc(RHS);
4066+
if (RHS->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
4067+
return SDValue();
4068+
4069+
SDValue VAND = RHS.getOperand(0);
4070+
if (VAND->getOpcode() != ISD::AND)
4071+
return SDValue();
4072+
4073+
ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1));
4074+
if (!CRRHS)
4075+
return SDValue();
4076+
4077+
SDValue LHSAND = VAND.getOperand(0);
4078+
SDValue RHSAND = VAND.getOperand(1);
4079+
if (RHSAND->getOpcode() != ISD::BUILD_VECTOR)
4080+
return SDValue();
4081+
4082+
ConstantSDNode *CANDL = dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
4083+
ConstantSDNode *CANDR = dyn_cast<ConstantSDNode>(RHSAND->getOperand(1));
4084+
if (!CANDL || !CANDR || RHSAND->getConstantOperandVal(0) != 0x1f ||
4085+
RHSAND->getConstantOperandVal(1) != 0x1f)
4086+
return SDValue();
4087+
// Get the non-const AND operands and produce scalar AND
4088+
const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
4089+
const SDValue One = DAG.getConstant(1, SL, MVT::i32);
4090+
SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, Zero);
4091+
SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4092+
SDValue AndMask = DAG.getConstant(0x1f, SL, MVT::i32);
4093+
SDValue LoAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Lo, AndMask);
4094+
SDValue HiAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Hi, AndMask);
4095+
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
4096+
uint64_t AndIndex = RHS->getConstantOperandVal(1);
4097+
if (AndIndex == 0 || AndIndex == 1)
4098+
return DAG.getNode(ShiftOpc, SL, MVT::i32, Trunc,
4099+
AndIndex == 0 ? LoAnd : HiAnd, RHS->getFlags());
4100+
4101+
return SDValue();
4102+
}
4103+
40654104
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40664105
DAGCombinerInfo &DCI) const {
40674106
EVT VT = N->getValueType(0);
@@ -4071,49 +4110,8 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40714110
SDLoc SL(N);
40724111
SelectionDAG &DAG = DCI.DAG;
40734112

4074-
if (RHS->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
4075-
SDValue VAND = RHS.getOperand(0);
4076-
if (ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1))) {
4077-
uint64_t AndIndex = RHS->getConstantOperandVal(1);
4078-
if (VAND->getOpcode() == ISD::AND && CRRHS) {
4079-
SDValue LHSAND = VAND.getOperand(0);
4080-
SDValue RHSAND = VAND.getOperand(1);
4081-
if (RHSAND->getOpcode() == ISD::BUILD_VECTOR) {
4082-
// Part of shlcombine is to optimise for the case where its possible
4083-
// to reduce shl64 to shl32 if shift range is [63-32]. This
4084-
// transforms: DST = shl i64 X, Y to [0, shl i32 X, (Y & 31) ]. The
4085-
// '&' is then elided by ISel. The vector code for this was being
4086-
// completely scalarised by the vector legalizer, but now v2i32 is
4087-
// made legal the vector legaliser only partially scalarises the
4088-
// vector operations and the and was not elided. This check enables us
4089-
// to locate and scalarise the v2i32 and and re-enable ISel to elide
4090-
// the and instruction.
4091-
ConstantSDNode *CANDL =
4092-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
4093-
ConstantSDNode *CANDR =
4094-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(1));
4095-
if (CANDL && CANDR && RHSAND->getConstantOperandVal(0) == 0x1f &&
4096-
RHSAND->getConstantOperandVal(1) == 0x1f) {
4097-
// Get the non-const AND operands and produce scalar AND
4098-
const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
4099-
const SDValue One = DAG.getConstant(1, SL, MVT::i32);
4100-
SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32,
4101-
LHSAND, Zero);
4102-
SDValue Hi =
4103-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4104-
SDValue LoAnd =
4105-
DAG.getNode(ISD::AND, SL, MVT::i32, Lo, RHSAND->getOperand(0));
4106-
SDValue HiAnd =
4107-
DAG.getNode(ISD::AND, SL, MVT::i32, Hi, RHSAND->getOperand(0));
4108-
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
4109-
if (AndIndex == 0 || AndIndex == 1)
4110-
return DAG.getNode(ISD::SHL, SL, MVT::i32, Trunc,
4111-
AndIndex == 0 ? LoAnd : HiAnd, N->getFlags());
4112-
}
4113-
}
4114-
}
4115-
}
4116-
}
4113+
if (SDValue SS = getShiftForReduction(ISD::SHL, LHS, RHS, DAG))
4114+
return SS;
41174115

41184116
unsigned RHSVal;
41194117
if (CRHS) {
@@ -4215,6 +4213,9 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
42154213
SelectionDAG &DAG = DCI.DAG;
42164214
SDLoc SL(N);
42174215

4216+
if (SDValue SS = getShiftForReduction(ISD::SRA, LHS, RHS, DAG))
4217+
return SS;
4218+
42184219
if (VT.getScalarType() != MVT::i64)
42194220
return SDValue();
42204221

@@ -4245,12 +4246,12 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
42454246
(ElementType.getSizeInBits() - 1)) {
42464247
ShiftAmt = ShiftFullAmt;
42474248
} else {
4248-
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
4249+
SDValue TruncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
42494250
const SDValue ShiftMask =
42504251
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
42514252
// This AND instruction will clamp out of bounds shift values.
42524253
// It will also be removed during later instruction selection.
4253-
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4254+
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, TruncShiftAmt, ShiftMask);
42544255
}
42554256

42564257
EVT ConcatType;
@@ -4317,48 +4318,8 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
43174318
SDLoc SL(N);
43184319
unsigned RHSVal;
43194320

4320-
if (RHS->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
4321-
SDValue VAND = RHS.getOperand(0);
4322-
if (ConstantSDNode *CRRHS = dyn_cast<ConstantSDNode>(RHS->getOperand(1))) {
4323-
uint64_t AndIndex = RHS->getConstantOperandVal(1);
4324-
if (VAND->getOpcode() == ISD::AND && CRRHS) {
4325-
SDValue LHSAND = VAND.getOperand(0);
4326-
SDValue RHSAND = VAND.getOperand(1);
4327-
if (RHSAND->getOpcode() == ISD::BUILD_VECTOR) {
4328-
// Part of srlcombine is to optimise for the case where its possible
4329-
// to reduce shl64 to shl32 if shift range is [63-32]. This
4330-
// transforms: DST = shl i64 X, Y to [0, srl i32 X, (Y & 31) ]. The
4331-
// '&' is then elided by ISel. The vector code for this was being
4332-
// completely scalarised by the vector legalizer, but now v2i32 is
4333-
// made legal the vector legaliser only partially scalarises the
4334-
// vector operations and the and was not elided. This check enables us
4335-
// to locate and scalarise the v2i32 and and re-enable ISel to elide
4336-
// the and instruction.
4337-
ConstantSDNode *CANDL =
4338-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(0));
4339-
ConstantSDNode *CANDR =
4340-
dyn_cast<ConstantSDNode>(RHSAND->getOperand(1));
4341-
if (CANDL && CANDR && RHSAND->getConstantOperandVal(0) == 0x1f &&
4342-
RHSAND->getConstantOperandVal(1) == 0x1f) {
4343-
// Get the non-const AND operands and produce scalar AND
4344-
const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
4345-
const SDValue One = DAG.getConstant(1, SL, MVT::i32);
4346-
SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32,
4347-
LHSAND, Zero);
4348-
SDValue Hi =
4349-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, LHSAND, One);
4350-
SDValue AndMask = DAG.getConstant(0x1f, SL, MVT::i32);
4351-
SDValue LoAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Lo, AndMask);
4352-
SDValue HiAnd = DAG.getNode(ISD::AND, SL, MVT::i32, Hi, AndMask);
4353-
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
4354-
if (AndIndex == 0 || AndIndex == 1)
4355-
return DAG.getNode(ISD::SRL, SL, MVT::i32, Trunc,
4356-
AndIndex == 0 ? LoAnd : HiAnd, N->getFlags());
4357-
}
4358-
}
4359-
}
4360-
}
4361-
}
4321+
if (SDValue SS = getShiftForReduction(ISD::SRL, LHS, RHS, DAG))
4322+
return SS;
43624323

43634324
if (CRHS) {
43644325
RHSVal = CRHS->getZExtValue();
@@ -4873,8 +4834,8 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
48734834
if (!AMDGPUTargetLowering::allUsesHaveSourceMods(N.getNode()))
48744835
return SDValue();
48754836

4876-
return distributeOpThroughSelect(DCI, LHS.getOpcode(), SDLoc(N), Cond, LHS,
4877-
RHS);
4837+
return distributeOpThroughSelect(DCI, LHS.getOpcode(),
4838+
SDLoc(N), Cond, LHS, RHS);
48784839
}
48794840

48804841
bool Inv = false;

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13526,34 +13526,51 @@ SDValue SITargetLowering::performXorCombine(SDNode *N,
1352613526
SDValue LHS = N->getOperand(0);
1352713527
SDValue RHS = N->getOperand(1);
1352813528

13529-
const ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
13529+
const ConstantSDNode *CRHS = isConstOrConstSplat(RHS);
1353013530

1353113531
if (CRHS && VT == MVT::i64) {
1353213532
if (SDValue Split =
1353313533
splitBinaryBitConstantOp(DCI, SDLoc(N), ISD::XOR, LHS, CRHS))
1353413534
return Split;
1353513535
}
1353613536

13537+
// v2i32 (xor (vselect cc, x, y), K) ->
13538+
// (v2i32 svelect cc, (xor x, K), (xor y, K)) This enables the xor to be
13539+
// replaced with source modifiers when the select is lowered to CNDMASK.
13540+
unsigned Opc = LHS.getOpcode();
13541+
if (((Opc == ISD::VSELECT && VT == MVT::v2i32) ||
13542+
(Opc == ISD::SELECT && VT == MVT::i64)) &&
13543+
CRHS && CRHS->getAPIntValue().isSignMask()) {
13544+
SDValue CC = LHS->getOperand(0);
13545+
SDValue TRUE = LHS->getOperand(1);
13546+
SDValue FALSE = LHS->getOperand(2);
13547+
SDValue XTrue = DAG.getNode(ISD::XOR, SDLoc(N), VT, TRUE, RHS);
13548+
SDValue XFalse = DAG.getNode(ISD::XOR, SDLoc(N), VT, FALSE, RHS);
13549+
SDValue XSelect =
13550+
DAG.getNode(ISD::VSELECT, SDLoc(N), VT, CC, XTrue, XFalse);
13551+
return XSelect;
13552+
}
13553+
1353713554
// Make sure to apply the 64-bit constant splitting fold before trying to fold
1353813555
// fneg-like xors into 64-bit select.
13539-
// if (LHS.getOpcode() == ISD::SELECT && VT == MVT::i32) {
13540-
// // This looks like an fneg, try to fold as a source modifier.
13541-
// if (CRHS && CRHS->getAPIntValue().isSignMask() &&
13542-
// shouldFoldFNegIntoSrc(N, LHS)) {
13543-
// // xor (select c, a, b), 0x80000000 ->
13544-
// // bitcast (select c, (fneg (bitcast a)), (fneg (bitcast b)))
13545-
// SDLoc DL(N);
13546-
// SDValue CastLHS =
13547-
// DAG.getNode(ISD::BITCAST, DL, MVT::f32, LHS->getOperand(1));
13548-
// SDValue CastRHS =
13549-
// DAG.getNode(ISD::BITCAST, DL, MVT::f32, LHS->getOperand(2));
13550-
// SDValue FNegLHS = DAG.getNode(ISD::FNEG, DL, MVT::f32, CastLHS);
13551-
// SDValue FNegRHS = DAG.getNode(ISD::FNEG, DL, MVT::f32, CastRHS);
13552-
// SDValue NewSelect = DAG.getNode(ISD::SELECT, DL, MVT::f32,
13553-
// LHS->getOperand(0), FNegLHS, FNegRHS);
13554-
// return DAG.getNode(ISD::BITCAST, DL, VT, NewSelect);
13555-
// }
13556-
// }
13556+
if (LHS.getOpcode() == ISD::SELECT && VT == MVT::i32) {
13557+
// This looks like an fneg, try to fold as a source modifier.
13558+
if (CRHS && CRHS->getAPIntValue().isSignMask() &&
13559+
shouldFoldFNegIntoSrc(N, LHS)) {
13560+
// xor (select c, a, b), 0x80000000 ->
13561+
// bitcast (select c, (fneg (bitcast a)), (fneg (bitcast b)))
13562+
SDLoc DL(N);
13563+
SDValue CastLHS =
13564+
DAG.getNode(ISD::BITCAST, DL, MVT::f32, LHS->getOperand(1));
13565+
SDValue CastRHS =
13566+
DAG.getNode(ISD::BITCAST, DL, MVT::f32, LHS->getOperand(2));
13567+
SDValue FNegLHS = DAG.getNode(ISD::FNEG, DL, MVT::f32, CastLHS);
13568+
SDValue FNegRHS = DAG.getNode(ISD::FNEG, DL, MVT::f32, CastRHS);
13569+
SDValue NewSelect = DAG.getNode(ISD::SELECT, DL, MVT::f32,
13570+
LHS->getOperand(0), FNegLHS, FNegRHS);
13571+
return DAG.getNode(ISD::BITCAST, DL, VT, NewSelect);
13572+
}
13573+
}
1355713574

1355813575
return SDValue();
1355913576
}

0 commit comments

Comments
 (0)