Skip to content

Commit 8550a5c

Browse files
committed
[NVPTX] Use PRMT more widely, and improve folding around this instruction
1 parent aa27d4e commit 8550a5c

File tree

10 files changed

+651
-653
lines changed

10 files changed

+651
-653
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#include "llvm/Support/CodeGen.h"
5858
#include "llvm/Support/CommandLine.h"
5959
#include "llvm/Support/ErrorHandling.h"
60+
#include "llvm/Support/KnownBits.h"
6061
#include "llvm/Support/NVPTXAddrSpace.h"
6162
#include "llvm/Support/raw_ostream.h"
6263
#include "llvm/Target/TargetMachine.h"
@@ -1070,7 +1071,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10701071
MAKE_CASE(NVPTXISD::StoreV8)
10711072
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
10721073
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
1073-
MAKE_CASE(NVPTXISD::BFE)
10741074
MAKE_CASE(NVPTXISD::BFI)
10751075
MAKE_CASE(NVPTXISD::PRMT)
10761076
MAKE_CASE(NVPTXISD::FCOPYSIGN)
@@ -2145,14 +2145,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21452145
EVT VectorVT = Vector.getValueType();
21462146

21472147
if (VectorVT == MVT::v4i8) {
2148-
SDValue BFE =
2149-
DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
2150-
{Vector,
2151-
DAG.getNode(ISD::MUL, DL, MVT::i32,
2152-
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2153-
DAG.getConstant(8, DL, MVT::i32)),
2154-
DAG.getConstant(8, DL, MVT::i32)});
2155-
return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
2148+
SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
2149+
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
2150+
DAG.getConstant(0x7770, DL, MVT::i32));
2151+
SDValue PRMT = DAG.getNode(
2152+
NVPTXISD::PRMT, DL, MVT::i32,
2153+
{DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),
2154+
Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2155+
return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
21562156
}
21572157

21582158
// Constant index will be matched by tablegen.
@@ -5206,31 +5206,6 @@ static SDValue PerformANDCombine(SDNode *N,
52065206

52075207
SDValue AExt;
52085208

5209-
// Convert BFE-> truncate i16 -> and 255
5210-
// To just BFE-> truncate i16, as the value already has all the bits in the
5211-
// right places.
5212-
if (Val.getOpcode() == ISD::TRUNCATE) {
5213-
SDValue BFE = Val.getOperand(0);
5214-
if (BFE.getOpcode() != NVPTXISD::BFE)
5215-
return SDValue();
5216-
5217-
ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
5218-
if (!BFEBits)
5219-
return SDValue();
5220-
uint64_t BFEBitsVal = BFEBits->getZExtValue();
5221-
5222-
ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
5223-
if (!MaskCnst) {
5224-
// Not an AND with a constant
5225-
return SDValue();
5226-
}
5227-
uint64_t MaskVal = MaskCnst->getZExtValue();
5228-
5229-
if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
5230-
return SDValue();
5231-
// If we get here, the AND is unnecessary. Just replace it with the trunc
5232-
DCI.CombineTo(N, Val, false);
5233-
}
52345209
// Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
52355210
if (Val.getOpcode() == ISD::ANY_EXTEND) {
52365211
AExt = Val;
@@ -6334,3 +6309,45 @@ MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
63346309
const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
63356310
return getDataSection();
63366311
}
6312+
6313+
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6314+
const SelectionDAG &DAG, unsigned Depth) {
6315+
SDValue A = Op.getOperand(0);
6316+
SDValue B = Op.getOperand(1);
6317+
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
6318+
unsigned Mode = Op.getConstantOperandVal(3);
6319+
6320+
if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6321+
return;
6322+
6323+
KnownBits AKnown = DAG.computeKnownBits(A, Depth);
6324+
KnownBits BKnown = DAG.computeKnownBits(B, Depth);
6325+
6326+
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6327+
KnownBits BitField = BKnown.concat(AKnown);
6328+
6329+
APInt SelectorVal = Selector->getAPIntValue();
6330+
for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
6331+
APInt Sel = SelectorVal.extractBits(4, I * 4);
6332+
unsigned Idx = Sel.getLoBits(3).getZExtValue();
6333+
unsigned Sign = Sel.getHiBits(1).getZExtValue();
6334+
KnownBits Byte = BitField.extractBits(8, Idx * 8);
6335+
if (Sign)
6336+
Byte = KnownBits::ashr(Byte, 8);
6337+
Known.insertBits(Byte, I * 8);
6338+
}
6339+
}
6340+
6341+
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
6342+
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
6343+
const SelectionDAG &DAG, unsigned Depth) const {
6344+
Known.resetAll();
6345+
6346+
switch (Op.getOpcode()) {
6347+
case NVPTXISD::PRMT:
6348+
computeKnownBitsForPRMT(Op, Known, DAG, Depth);
6349+
break;
6350+
default:
6351+
break;
6352+
}
6353+
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ enum NodeType : unsigned {
5050
MUL_WIDE_UNSIGNED,
5151
SETP_F16X2,
5252
SETP_BF16X2,
53-
BFE,
5453
BFI,
5554
PRMT,
5655

@@ -272,6 +271,11 @@ class NVPTXTargetLowering : public TargetLowering {
272271
unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
273272
EVT ToVT) const override;
274273

274+
void computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known,
275+
const APInt &DemandedElts,
276+
const SelectionDAG &DAG,
277+
unsigned Depth = 0) const override;
278+
275279
private:
276280
const NVPTXSubtarget &STI; // cache the subtarget here
277281
mutable unsigned GlobalUniqueCallSite;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,11 +1359,6 @@ def BREV64 :
13591359
// restriction in PTX?
13601360
//
13611361
// dest and src may be int32 or int64, but start and end are always int32.
1362-
def SDTBFE :
1363-
SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
1364-
SDTCisVT<2, i32>, SDTCisVT<3, i32>]>;
1365-
def bfe : SDNode<"NVPTXISD::BFE", SDTBFE>;
1366-
13671362
def SDTBFI :
13681363
SDTypeProfile<1, 4, [SDTCisInt<0>, SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
13691364
SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
@@ -1374,22 +1369,13 @@ def SDTPRMT :
13741369
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
13751370
def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
13761371

1377-
multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
1372+
multiclass BFE<string Instr, RegisterClass RC> {
13781373
def rrr
1379-
: BasicNVPTXInst<(outs RC:$d),
1380-
(ins RC:$a, B32:$b, B32:$c),
1381-
Instr,
1382-
[(set T:$d, (bfe T:$a, i32:$b, i32:$c))]>;
1374+
: BasicNVPTXInst<(outs RC:$d), (ins RC:$a, B32:$b, B32:$c), Instr>;
13831375
def rri
1384-
: BasicNVPTXInst<(outs RC:$d),
1385-
(ins RC:$a, B32:$b, i32imm:$c),
1386-
Instr,
1387-
[(set T:$d, (bfe T:$a, i32:$b, imm:$c))]>;
1376+
: BasicNVPTXInst<(outs RC:$d), (ins RC:$a, B32:$b, i32imm:$c), Instr>;
13881377
def rii
1389-
: BasicNVPTXInst<(outs RC:$d),
1390-
(ins RC:$a, i32imm:$b, i32imm:$c),
1391-
Instr,
1392-
[(set T:$d, (bfe T:$a, imm:$b, imm:$c))]>;
1378+
: BasicNVPTXInst<(outs RC:$d), (ins RC:$a, i32imm:$b, i32imm:$c), Instr>;
13931379
}
13941380

13951381
multiclass BFI<string Instr, ValueType T, RegisterClass RC, Operand ImmCls> {
@@ -1434,10 +1420,10 @@ let hasSideEffects = false in {
14341420
// the same patterns, so the first one wins. Having unsigned byte extraction
14351421
// has the benefit of always having zero in unused bits, which makes some
14361422
// optimizations easier (e.g. no need to mask them).
1437-
defm BFE_U32 : BFE<"bfe.u32", i32, B32>;
1438-
defm BFE_S32 : BFE<"bfe.s32", i32, B32>;
1439-
defm BFE_U64 : BFE<"bfe.u64", i64, B64>;
1440-
defm BFE_S64 : BFE<"bfe.s64", i64, B64>;
1423+
defm BFE_U32 : BFE<"bfe.u32", B32>;
1424+
defm BFE_S32 : BFE<"bfe.s32", B32>;
1425+
defm BFE_U64 : BFE<"bfe.u64", B64>;
1426+
defm BFE_S64 : BFE<"bfe.s64", B64>;
14411427

14421428
defm BFI_B32 : BFI<"bfi.b32", i32, B32, i32imm>;
14431429
defm BFI_B64 : BFI<"bfi.b64", i64, B64, i64imm>;
@@ -1474,19 +1460,26 @@ def : Pat<(fshr i32:$hi, i32:$lo, (shl i32:$amt, (i32 3))),
14741460
(PRMT_B32rrr $lo, $hi, $amt, PrmtF4E)>;
14751461

14761462

1463+
def byte_extract_prmt : ImmLeaf<i32, [{
1464+
return (Imm == 0x7770) || (Imm == 0x7771) || (Imm == 0x7772) || (Imm == 0x7773);
1465+
}]>;
1466+
1467+
def to_sign_extend_selector : SDNodeXForm<imm, [{
1468+
const APInt &V = N->getAPIntValue();
1469+
const APInt B = V.trunc(4);
1470+
const APInt BSext = B | 8;
1471+
const APInt R = BSext.concat(BSext).concat(BSext).concat(B).zext(32);
1472+
return CurDAG->getTargetConstant(R, SDLoc(N), MVT::i32);
1473+
}]>;
1474+
1475+
14771476
// byte extraction + signed/unsigned extension to i32.
1478-
def : Pat<(i32 (sext_inreg (bfe i32:$s, i32:$o, 8), i8)),
1479-
(BFE_S32rri $s, $o, 8)>;
1480-
def : Pat<(i32 (sext_inreg (bfe i32:$s, imm:$o, 8), i8)),
1481-
(BFE_S32rii $s, imm:$o, 8)>;
1482-
def : Pat<(i32 (and (bfe i32:$s, i32:$o, 8), 255)),
1483-
(BFE_U32rri $s, $o, 8)>;
1484-
def : Pat<(i32 (and (bfe i32:$s, imm:$o, 8), 255)),
1485-
(BFE_U32rii $s, imm:$o, 8)>;
1477+
def : Pat<(i32 (sext_inreg (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtNONE), i8)),
1478+
(PRMT_B32rii $s, 0, (to_sign_extend_selector $sel), PrmtNONE)>;
14861479

14871480
// byte extraction + signed extension to i16
1488-
def : Pat<(i16 (sext_inreg (trunc (bfe i32:$s, imm:$o, 8)), i8)),
1489-
(CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>;
1481+
def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtNONE)), i8)),
1482+
(CVT_u16_u32 (PRMT_B32rii $s, 0, (to_sign_extend_selector $sel), PrmtNONE), CvtNONE)>;
14901483

14911484

14921485
// Byte extraction via shift/trunc/sext
@@ -1699,25 +1692,33 @@ def cond_not_signed : PatLeaf<(cond), [{
16991692
// comparisons of i8 extracted with PRMT as i32
17001693
// It's faster to do comparison directly on i32 extracted by PRMT,
17011694
// instead of the long conversion and sign extending.
1702-
def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (bfe B32:$a, B32:$oa, 8))), i8)),
1703-
(i16 (sext_inreg (i16 (trunc (bfe B32:$b, B32:$ob, 8))), i8)),
1695+
// Signed comparison of two sign-extended i8 values that were extracted from an
// i32 with PRMT. The compare is done on i32, so each extraction must itself
// sign-extend the selected byte: the matched selector zero-fills the upper
// bytes, so it is rewritten with to_sign_extend_selector before being reused.
def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)),
                (i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)),
                cond_signed:$cc),
         (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE),
                     (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE),
                     (cond2cc $cc))>;
17061701

1707-
def: Pat<(setcc (i16 (sext_inreg (trunc (bfe B32:$a, imm:$oa, 8)), i8)),
1708-
(i16 (sext_inreg (trunc (bfe B32:$b, imm:$ob, 8)), i8)),
1702+
// Same signed i8 comparison without an explicit type on the trunc. As the
// compare is done on i32, the zero-filling extraction selector is rewritten
// with to_sign_extend_selector so both PRMT results are sign-extended bytes.
def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
                (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
                cond_signed:$cc),
         (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE),
                     (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE),
                     (cond2cc $cc))>;
17111708

1712-
def: Pat<(setcc (i16 (and (trunc (bfe B32:$a, B32:$oa, 8)), 255)),
1713-
(i16 (and (trunc (bfe B32:$b, B32:$ob, 8)), 255)),
1709+
def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
1710+
(i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
17141711
cond_signed:$cc),
1715-
(SETP_i32rr (BFE_U32rri $a, $oa, 8), (BFE_U32rri $b, $ob, 8), (cond2cc $cc))>;
1712+
(SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
1713+
(PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
1714+
(cond2cc $cc))>;
17161715

1717-
def: Pat<(setcc (i16 (and (trunc (bfe B32:$a, imm:$oa, 8)), 255)),
1718-
(i16 (and (trunc (bfe B32:$b, imm:$ob, 8)), 255)),
1716+
def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
1717+
(i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
17191718
cond_not_signed:$cc),
1720-
(SETP_i32rr (BFE_U32rii $a, imm:$oa, 8), (BFE_U32rii $b, imm:$ob, 8), (cond2cc $cc))>;
1719+
(SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
1720+
(PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
1721+
(cond2cc $cc))>;
17211722

17221723
def SDTDeclareArrayParam :
17231724
SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>;

0 commit comments

Comments
 (0)