Skip to content

Commit a0b56b6

Browse files
committed
[NVPTX] Vectorize and lower 256-bit global loads/stores for sm_100+ and ptx88+
1 parent 790ce0e commit a0b56b6

15 files changed

+3964
-21
lines changed

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
319319
case NVPTX::PTXLdStInstCode::V4:
320320
O << ".v4";
321321
return;
322+
case NVPTX::PTXLdStInstCode::V8:
323+
O << ".v8";
324+
return;
322325
}
323326
// TODO: evaluate whether cases not covered by this switch are bugs
324327
return;

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ enum FromType {
199199
enum VecType {
200200
Scalar = 1,
201201
V2 = 2,
202-
V4 = 4
202+
V4 = 4,
203+
V8 = 8
203204
};
204205
} // namespace PTXLdStInstCode
205206

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
129129
return;
130130
case NVPTXISD::LoadV2:
131131
case NVPTXISD::LoadV4:
132+
case NVPTXISD::LoadV8:
132133
if (tryLoadVector(N))
133134
return;
134135
break;
@@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
139140
break;
140141
case NVPTXISD::StoreV2:
141142
case NVPTXISD::StoreV4:
143+
case NVPTXISD::StoreV8:
142144
if (tryStoreVector(N))
143145
return;
144146
break;
@@ -1195,6 +1197,12 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11951197
FromTypeWidth = TotalWidth / 4;
11961198
VecType = NVPTX::PTXLdStInstCode::V4;
11971199
break;
1200+
case NVPTXISD::LoadV8:
1201+
if (!Subtarget->has256BitMaskedLoadStore())
1202+
return false;
1203+
FromTypeWidth = TotalWidth / 8;
1204+
VecType = NVPTX::PTXLdStInstCode::V8;
1205+
break;
11981206
default:
11991207
return false;
12001208
}
@@ -1205,7 +1213,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12051213
}
12061214

12071215
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1208-
FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
1216+
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
12091217

12101218
SDValue Offset, Base;
12111219
SelectADDR(N->getOperand(1), Base, Offset);
@@ -1230,9 +1238,22 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12301238
NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
12311239
break;
12321240
case NVPTXISD::LoadV4:
1233-
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
1234-
NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
1235-
NVPTX::LDV_f32_v4, std::nullopt);
1241+
Opcode =
1242+
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
1243+
NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
1244+
NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
1245+
break;
1246+
case NVPTXISD::LoadV8:
1247+
switch (EltVT.getSimpleVT().SimpleTy) {
1248+
case MVT::i32:
1249+
Opcode = NVPTX::LDV_i32_v8;
1250+
break;
1251+
case MVT::f32:
1252+
Opcode = NVPTX::LDV_f32_v8;
1253+
break;
1254+
default:
1255+
return false;
1256+
}
12361257
break;
12371258
}
12381259
if (!Opcode)
@@ -1328,14 +1349,33 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13281349
Opcode = pickOpcodeForVT(
13291350
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
13301351
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1331-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1352+
NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
1353+
NVPTX::INT_PTX_LDG_G_v4f64_ELE);
13321354
break;
13331355
case NVPTXISD::LDUV4:
13341356
Opcode = pickOpcodeForVT(
13351357
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
13361358
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
13371359
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
13381360
break;
1361+
case NVPTXISD::LoadV8:
1362+
switch (EltVT.getSimpleVT().SimpleTy) {
1363+
case MVT::i32:
1364+
Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
1365+
break;
1366+
case MVT::f32:
1367+
Opcode = NVPTX::INT_PTX_LDG_G_v8f32_ELE;
1368+
break;
1369+
case MVT::v2i16:
1370+
case MVT::v2f16:
1371+
case MVT::v2bf16:
1372+
case MVT::v4i8:
1373+
Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
1374+
break;
1375+
default:
1376+
return false;
1377+
}
1378+
break;
13391379
}
13401380
if (!Opcode)
13411381
return false;
@@ -1502,6 +1542,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15021542
N2 = N->getOperand(5);
15031543
ToTypeWidth = TotalWidth / 4;
15041544
break;
1545+
case NVPTXISD::StoreV8:
1546+
if (!Subtarget->has256BitMaskedLoadStore())
1547+
return false;
1548+
VecType = NVPTX::PTXLdStInstCode::V8;
1549+
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
1550+
N->getOperand(4), N->getOperand(5), N->getOperand(6),
1551+
N->getOperand(7), N->getOperand(8)});
1552+
N2 = N->getOperand(9);
1553+
ToTypeWidth = TotalWidth / 8;
1554+
break;
15051555
default:
15061556
return false;
15071557
}
@@ -1512,7 +1562,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15121562
}
15131563

15141564
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
1515-
TotalWidth <= 128 && "Invalid width for store");
1565+
TotalWidth <= 256 && "Invalid width for store");
15161566

15171567
SDValue Offset, Base;
15181568
SelectADDR(N2, Base, Offset);
@@ -1533,9 +1583,22 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15331583
NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
15341584
break;
15351585
case NVPTXISD::StoreV4:
1536-
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
1537-
NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
1538-
NVPTX::STV_f32_v4, std::nullopt);
1586+
Opcode =
1587+
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
1588+
NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
1589+
NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
1590+
break;
1591+
case NVPTXISD::StoreV8:
1592+
switch (EltVT.getSimpleVT().SimpleTy) {
1593+
case MVT::i32:
1594+
Opcode = NVPTX::STV_i32_v8;
1595+
break;
1596+
case MVT::f32:
1597+
Opcode = NVPTX::STV_f32_v8;
1598+
break;
1599+
default:
1600+
return false;
1601+
}
15391602
break;
15401603
}
15411604

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,14 @@ static bool IsPTXVectorType(MVT VT) {
162162
case MVT::v2f32:
163163
case MVT::v4f32:
164164
case MVT::v2f64:
165+
case MVT::v4i64:
166+
case MVT::v4f64:
167+
case MVT::v8i32:
168+
case MVT::v8f32:
169+
case MVT::v16f16: // <8 x f16x2>
170+
case MVT::v16bf16: // <8 x bf16x2>
171+
case MVT::v16i16: // <8 x i16x2>
172+
case MVT::v32i8: // <8 x i8x4>
165173
return true;
166174
}
167175
}
@@ -179,7 +187,7 @@ static bool Is16bitsType(MVT VT) {
179187
// - unsigned int NumElts - The number of elements in the final vector
180188
// - EVT EltVT - The type of the elements in the final vector
181189
static std::optional<std::pair<unsigned int, MVT>>
182-
getVectorLoweringShape(EVT VectorEVT) {
190+
getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
183191
if (!VectorEVT.isSimple())
184192
return std::nullopt;
185193
const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -199,6 +207,15 @@ getVectorLoweringShape(EVT VectorEVT) {
199207
switch (VectorVT.SimpleTy) {
200208
default:
201209
return std::nullopt;
210+
case MVT::v4i64:
211+
case MVT::v4f64:
212+
case MVT::v8i32:
213+
case MVT::v8f32:
214+
// This is a "native" vector type iff the address space is global
215+
// and the target supports 256-bit loads/stores
216+
if (!CanLowerTo256Bit)
217+
return std::nullopt;
218+
LLVM_FALLTHROUGH;
202219
case MVT::v2i8:
203220
case MVT::v2i16:
204221
case MVT::v2i32:
@@ -215,6 +232,15 @@ getVectorLoweringShape(EVT VectorEVT) {
215232
case MVT::v4f32:
216233
// This is a "native" vector type
217234
return std::pair(NumElts, EltVT);
235+
case MVT::v16f16: // <8 x f16x2>
236+
case MVT::v16bf16: // <8 x bf16x2>
237+
case MVT::v16i16: // <8 x i16x2>
238+
case MVT::v32i8: // <8 x i8x4>
239+
// This can be upsized into a "native" vector type iff the address space is
240+
// global and the target supports 256-bit loads/stores.
241+
if (!CanLowerTo256Bit)
242+
return std::nullopt;
243+
LLVM_FALLTHROUGH;
218244
case MVT::v8i8: // <2 x i8x4>
219245
case MVT::v8f16: // <4 x f16x2>
220246
case MVT::v8bf16: // <4 x bf16x2>
@@ -1070,10 +1096,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10701096
MAKE_CASE(NVPTXISD::ProxyReg)
10711097
MAKE_CASE(NVPTXISD::LoadV2)
10721098
MAKE_CASE(NVPTXISD::LoadV4)
1099+
MAKE_CASE(NVPTXISD::LoadV8)
10731100
MAKE_CASE(NVPTXISD::LDUV2)
10741101
MAKE_CASE(NVPTXISD::LDUV4)
10751102
MAKE_CASE(NVPTXISD::StoreV2)
10761103
MAKE_CASE(NVPTXISD::StoreV4)
1104+
MAKE_CASE(NVPTXISD::StoreV8)
10771105
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
10781106
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
10791107
MAKE_CASE(NVPTXISD::BFE)
@@ -3201,7 +3229,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32013229
if (ValVT != MemVT)
32023230
return SDValue();
32033231

3204-
const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
3232+
// 256-bit vectors are allowed iff the address space is global
3233+
// and the target supports 256-bit loads/stores
3234+
unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
3235+
bool CanLowerTo256Bit =
3236+
AddrSpace == ADDRESS_SPACE_GLOBAL && STI.has256BitMaskedLoadStore();
3237+
const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT, CanLowerTo256Bit);
32053238
if (!NumEltsAndEltVT)
32063239
return SDValue();
32073240
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3229,6 +3262,9 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32293262
case 4:
32303263
Opcode = NVPTXISD::StoreV4;
32313264
break;
3265+
case 8:
3266+
Opcode = NVPTXISD::StoreV8;
3267+
break;
32323268
}
32333269

32343270
SmallVector<SDValue, 8> Ops;
@@ -5765,7 +5801,8 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
57655801

57665802
/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
57675803
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5768-
SmallVectorImpl<SDValue> &Results) {
5804+
SmallVectorImpl<SDValue> &Results,
5805+
bool TargetHas256BitVectorLoadStore) {
57695806
LoadSDNode *LD = cast<LoadSDNode>(N);
57705807
const EVT ResVT = LD->getValueType(0);
57715808
const EVT MemVT = LD->getMemoryVT();
@@ -5775,7 +5812,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
57755812
if (ResVT != MemVT)
57765813
return;
57775814

5778-
const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
5815+
// 256-bit vectors are allowed iff the address space is global
5816+
// and the target supports 256-bit loads/stores
5817+
unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
5818+
bool CanLowerTo256Bit =
5819+
AddrSpace == ADDRESS_SPACE_GLOBAL && TargetHas256BitVectorLoadStore;
5820+
const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT, CanLowerTo256Bit);
57795821
if (!NumEltsAndEltVT)
57805822
return;
57815823
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -5812,6 +5854,13 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
58125854
DAG.getVTList({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
58135855
break;
58145856
}
5857+
case 8: {
5858+
Opcode = NVPTXISD::LoadV8;
5859+
EVT ListVTs[] = {LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT,
5860+
LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other};
5861+
LdResVTs = DAG.getVTList(ListVTs);
5862+
break;
5863+
}
58155864
}
58165865
SDLoc DL(LD);
58175866

@@ -6084,7 +6133,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
60846133
ReplaceBITCAST(N, DAG, Results);
60856134
return;
60866135
case ISD::LOAD:
6087-
ReplaceLoadVector(N, DAG, Results);
6136+
ReplaceLoadVector(N, DAG, Results, STI.has256BitMaskedLoadStore());
60886137
return;
60896138
case ISD::INTRINSIC_W_CHAIN:
60906139
ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ enum NodeType : unsigned {
8484
FIRST_MEMORY_OPCODE,
8585
LoadV2 = FIRST_MEMORY_OPCODE,
8686
LoadV4,
87+
LoadV8,
8788
LDUV2, // LDU.v2
8889
LDUV4, // LDU.v4
8990
StoreV2,
9091
StoreV4,
92+
StoreV8,
9193
LoadParam,
9294
LoadParamV2,
9395
LoadParamV4,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,7 +2425,7 @@ let mayStore=1, hasSideEffects=0 in {
24252425
// The following are used only during and after vector elementization. Vector
24262426
// elementization happens at the machine instruction level, so the following
24272427
// instructions never appear in the DAG.
2428-
multiclass LD_VEC<NVPTXRegClass regclass> {
2428+
multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
24292429
def _v2 : NVPTXInst<
24302430
(outs regclass:$dst1, regclass:$dst2),
24312431
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
@@ -2438,17 +2438,27 @@ multiclass LD_VEC<NVPTXRegClass regclass> {
24382438
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
24392439
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
24402440
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];", []>;
2441+
if support_v8 then {
2442+
def _v8 : NVPTXInst<
2443+
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
2444+
regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
2445+
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
2446+
i32imm:$fromWidth, ADDR:$addr),
2447+
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
2448+
"\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
2449+
"[$addr];", []>;
2450+
}
24412451
}
24422452
let mayLoad=1, hasSideEffects=0 in {
24432453
defm LDV_i8 : LD_VEC<Int16Regs>;
24442454
defm LDV_i16 : LD_VEC<Int16Regs>;
2445-
defm LDV_i32 : LD_VEC<Int32Regs>;
2455+
defm LDV_i32 : LD_VEC<Int32Regs, true>;
24462456
defm LDV_i64 : LD_VEC<Int64Regs>;
2447-
defm LDV_f32 : LD_VEC<Float32Regs>;
2457+
defm LDV_f32 : LD_VEC<Float32Regs, true>;
24482458
defm LDV_f64 : LD_VEC<Float64Regs>;
24492459
}
24502460

2451-
multiclass ST_VEC<NVPTXRegClass regclass> {
2461+
multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
24522462
def _v2 : NVPTXInst<
24532463
(outs),
24542464
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
@@ -2463,14 +2473,25 @@ multiclass ST_VEC<NVPTXRegClass regclass> {
24632473
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
24642474
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
24652475
"\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
2476+
if support_v8 then {
2477+
def _v8 : NVPTXInst<
2478+
(outs),
2479+
(ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
2480+
regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
2481+
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
2482+
i32imm:$fromWidth, ADDR:$addr),
2483+
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
2484+
"\t[$addr], "
2485+
"{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
2486+
}
24662487
}
24672488

24682489
let mayStore=1, hasSideEffects=0 in {
24692490
defm STV_i8 : ST_VEC<Int16Regs>;
24702491
defm STV_i16 : ST_VEC<Int16Regs>;
2471-
defm STV_i32 : ST_VEC<Int32Regs>;
2492+
defm STV_i32 : ST_VEC<Int32Regs, true>;
24722493
defm STV_i64 : ST_VEC<Int64Regs>;
2473-
defm STV_f32 : ST_VEC<Float32Regs>;
2494+
defm STV_f32 : ST_VEC<Float32Regs, true>;
24742495
defm STV_f64 : ST_VEC<Float64Regs>;
24752496
}
24762497

0 commit comments

Comments
 (0)