Skip to content

Commit a1e1a84

Browse files
authored
[NVPTX] Vectorize and lower 256-bit global loads/stores for sm_100+/ptx88+ (#139292)
PTX 8.8+ introduces 256-bit-wide vector loads/stores under certain conditions. This change extends the backend to lower these loads/stores. It also overrides getLoadStoreVecRegBitWidth for NVPTX, allowing the LoadStoreVectorizer to create these wider vector operations. See the spec for the three relevant PTX instructions here: - https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld - https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld-global-nc - https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
1 parent 952b680 commit a1e1a84

18 files changed

+3966
-118
lines changed

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -311,17 +311,6 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
311311
default:
312312
llvm_unreachable("Unknown register type");
313313
}
314-
} else if (Modifier == "vec") {
315-
switch (Imm) {
316-
case NVPTX::PTXLdStInstCode::V2:
317-
O << ".v2";
318-
return;
319-
case NVPTX::PTXLdStInstCode::V4:
320-
O << ".v4";
321-
return;
322-
}
323-
// TODO: evaluate whether cases not covered by this switch are bugs
324-
return;
325314
}
326315
llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
327316
}

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,7 @@ enum AddressSpace : AddressSpaceUnderlyingType {
190190
};
191191

192192
namespace PTXLdStInstCode {
193-
enum FromType {
194-
Unsigned = 0,
195-
Signed,
196-
Float,
197-
Untyped
198-
};
199-
enum VecType {
200-
Scalar = 1,
201-
V2 = 2,
202-
V4 = 4
203-
};
193+
enum FromType { Unsigned = 0, Signed, Float, Untyped };
204194
} // namespace PTXLdStInstCode
205195

206196
/// PTXCvtMode - Conversion code enumeration

llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
105105
const MachineOperand *ParamSymbol = Mov.uses().begin();
106106
assert(ParamSymbol->isSymbol());
107107

108-
constexpr unsigned LDInstBasePtrOpIdx = 6;
108+
constexpr unsigned LDInstBasePtrOpIdx = 5;
109109
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
110110
for (auto *LI : LoadInsts) {
111111
(LI->uses().begin() + LDInstBasePtrOpIdx)

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
129129
return;
130130
case NVPTXISD::LoadV2:
131131
case NVPTXISD::LoadV4:
132+
case NVPTXISD::LoadV8:
132133
if (tryLoadVector(N))
133134
return;
134135
break;
@@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
139140
break;
140141
case NVPTXISD::StoreV2:
141142
case NVPTXISD::StoreV4:
143+
case NVPTXISD::StoreV8:
142144
if (tryStoreVector(N))
143145
return;
144146
break;
@@ -1012,11 +1014,11 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
10121014

10131015
// Helper function template to reduce amount of boilerplate code for
10141016
// opcode selection.
1015-
static std::optional<unsigned>
1016-
pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
1017-
unsigned Opcode_i16, unsigned Opcode_i32,
1018-
std::optional<unsigned> Opcode_i64, unsigned Opcode_f32,
1019-
std::optional<unsigned> Opcode_f64) {
1017+
static std::optional<unsigned> pickOpcodeForVT(
1018+
MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
1019+
std::optional<unsigned> Opcode_i16, std::optional<unsigned> Opcode_i32,
1020+
std::optional<unsigned> Opcode_i64, std::optional<unsigned> Opcode_f32,
1021+
std::optional<unsigned> Opcode_f64) {
10201022
switch (VT) {
10211023
case MVT::i1:
10221024
case MVT::i8:
@@ -1091,7 +1093,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10911093
SDValue Ops[] = {getI32Imm(Ordering, DL),
10921094
getI32Imm(Scope, DL),
10931095
getI32Imm(CodeAddrSpace, DL),
1094-
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
10951096
getI32Imm(FromType, DL),
10961097
getI32Imm(FromTypeWidth, DL),
10971098
Base,
@@ -1128,6 +1129,22 @@ static bool isSubVectorPackedInI32(EVT EltVT) {
11281129
return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
11291130
}
11301131

1132+
static unsigned getLoadStoreVectorNumElts(SDNode *N) {
1133+
switch (N->getOpcode()) {
1134+
case NVPTXISD::LoadV2:
1135+
case NVPTXISD::StoreV2:
1136+
return 2;
1137+
case NVPTXISD::LoadV4:
1138+
case NVPTXISD::StoreV4:
1139+
return 4;
1140+
case NVPTXISD::LoadV8:
1141+
case NVPTXISD::StoreV8:
1142+
return 8;
1143+
default:
1144+
llvm_unreachable("Unexpected opcode");
1145+
}
1146+
}
1147+
11311148
bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11321149
MemSDNode *MemSD = cast<MemSDNode>(N);
11331150
const EVT MemEVT = MemSD->getMemoryVT();
@@ -1159,35 +1176,21 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11591176
? NVPTX::PTXLdStInstCode::Signed
11601177
: NVPTX::PTXLdStInstCode::Untyped;
11611178

1162-
unsigned VecType;
1163-
unsigned FromTypeWidth;
1164-
switch (N->getOpcode()) {
1165-
case NVPTXISD::LoadV2:
1166-
FromTypeWidth = TotalWidth / 2;
1167-
VecType = NVPTX::PTXLdStInstCode::V2;
1168-
break;
1169-
case NVPTXISD::LoadV4:
1170-
FromTypeWidth = TotalWidth / 4;
1171-
VecType = NVPTX::PTXLdStInstCode::V4;
1172-
break;
1173-
default:
1174-
return false;
1175-
}
1179+
unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
11761180

11771181
if (isSubVectorPackedInI32(EltVT)) {
11781182
assert(ExtensionType == ISD::NON_EXTLOAD);
11791183
EltVT = MVT::i32;
11801184
}
11811185

11821186
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1183-
FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
1187+
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
11841188

11851189
SDValue Offset, Base;
11861190
SelectADDR(N->getOperand(1), Base, Offset);
11871191
SDValue Ops[] = {getI32Imm(Ordering, DL),
11881192
getI32Imm(Scope, DL),
11891193
getI32Imm(CodeAddrSpace, DL),
1190-
getI32Imm(VecType, DL),
11911194
getI32Imm(FromType, DL),
11921195
getI32Imm(FromTypeWidth, DL),
11931196
Base,
@@ -1205,9 +1208,16 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12051208
NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
12061209
break;
12071210
case NVPTXISD::LoadV4:
1208-
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
1209-
NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
1210-
NVPTX::LDV_f32_v4, std::nullopt);
1211+
Opcode =
1212+
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
1213+
NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
1214+
NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
1215+
break;
1216+
case NVPTXISD::LoadV8:
1217+
Opcode =
1218+
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
1219+
{/* no v8i16 */}, NVPTX::LDV_i32_v8, {/* no v8i64 */},
1220+
NVPTX::LDV_f32_v8, {/* no v8f64 */});
12111221
break;
12121222
}
12131223
if (!Opcode)
@@ -1303,13 +1313,20 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13031313
Opcode = pickOpcodeForVT(
13041314
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
13051315
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1306-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1316+
NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
1317+
NVPTX::INT_PTX_LDG_G_v4f64_ELE);
13071318
break;
13081319
case NVPTXISD::LDUV4:
13091320
Opcode = pickOpcodeForVT(
13101321
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
13111322
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1312-
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1323+
{/* no v4i64 */}, NVPTX::INT_PTX_LDU_G_v4f32_ELE, {/* no v4f64 */});
1324+
break;
1325+
case NVPTXISD::LoadV8:
1326+
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
1327+
{/* no v8i16 */}, NVPTX::INT_PTX_LDG_G_v8i32_ELE,
1328+
{/* no v8i64 */}, NVPTX::INT_PTX_LDG_G_v8f32_ELE,
1329+
{/* no v8f64 */});
13131330
break;
13141331
}
13151332
if (!Opcode)
@@ -1395,7 +1412,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
13951412
getI32Imm(Ordering, DL),
13961413
getI32Imm(Scope, DL),
13971414
getI32Imm(CodeAddrSpace, DL),
1398-
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
13991415
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
14001416
getI32Imm(ToTypeWidth, DL),
14011417
Base,
@@ -1443,41 +1459,24 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14431459
// - for integer type, always use 'u'
14441460
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
14451461

1446-
SmallVector<SDValue, 12> Ops;
1447-
SDValue N2;
1448-
unsigned VecType;
1449-
unsigned ToTypeWidth;
1462+
unsigned NumElts = getLoadStoreVectorNumElts(N);
14501463

1451-
switch (N->getOpcode()) {
1452-
case NVPTXISD::StoreV2:
1453-
VecType = NVPTX::PTXLdStInstCode::V2;
1454-
Ops.append({N->getOperand(1), N->getOperand(2)});
1455-
N2 = N->getOperand(3);
1456-
ToTypeWidth = TotalWidth / 2;
1457-
break;
1458-
case NVPTXISD::StoreV4:
1459-
VecType = NVPTX::PTXLdStInstCode::V4;
1460-
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
1461-
N->getOperand(4)});
1462-
N2 = N->getOperand(5);
1463-
ToTypeWidth = TotalWidth / 4;
1464-
break;
1465-
default:
1466-
return false;
1467-
}
1464+
SmallVector<SDValue, 16> Ops(N->ops().slice(1, NumElts));
1465+
SDValue N2 = N->getOperand(NumElts + 1);
1466+
unsigned ToTypeWidth = TotalWidth / NumElts;
14681467

14691468
if (isSubVectorPackedInI32(EltVT)) {
14701469
EltVT = MVT::i32;
14711470
}
14721471

14731472
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
1474-
TotalWidth <= 128 && "Invalid width for store");
1473+
TotalWidth <= 256 && "Invalid width for store");
14751474

14761475
SDValue Offset, Base;
14771476
SelectADDR(N2, Base, Offset);
14781477

14791478
Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
1480-
getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
1479+
getI32Imm(CodeAddrSpace, DL),
14811480
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
14821481
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});
14831482

@@ -1492,9 +1491,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14921491
NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
14931492
break;
14941493
case NVPTXISD::StoreV4:
1495-
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
1496-
NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
1497-
NVPTX::STV_f32_v4, std::nullopt);
1494+
Opcode =
1495+
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
1496+
NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
1497+
NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
1498+
break;
1499+
case NVPTXISD::StoreV8:
1500+
Opcode =
1501+
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
1502+
{/* no v8i16 */}, NVPTX::STV_i32_v8, {/* no v8i64 */},
1503+
NVPTX::STV_f32_v8, {/* no v8f64 */});
14981504
break;
14991505
}
15001506

@@ -1554,10 +1560,10 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
15541560
NVPTX::LoadParamMemV2F64);
15551561
break;
15561562
case 4:
1557-
Opcode =
1558-
pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV4I8,
1559-
NVPTX::LoadParamMemV4I16, NVPTX::LoadParamMemV4I32,
1560-
std::nullopt, NVPTX::LoadParamMemV4F32, std::nullopt);
1563+
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1564+
NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
1565+
NVPTX::LoadParamMemV4I32, {/* no v4i64 */},
1566+
NVPTX::LoadParamMemV4F32, {/* no v4f64 */});
15611567
break;
15621568
}
15631569
if (!Opcode)
@@ -1648,8 +1654,8 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
16481654
case 4:
16491655
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
16501656
NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
1651-
NVPTX::StoreRetvalV4I32, std::nullopt,
1652-
NVPTX::StoreRetvalV4F32, std::nullopt);
1657+
NVPTX::StoreRetvalV4I32, {/* no v4i64 */},
1658+
NVPTX::StoreRetvalV4F32, {/* no v4f64 */});
16531659
break;
16541660
}
16551661
if (!Opcode)

0 commit comments

Comments
 (0)