Skip to content

Commit 9aad491

Browse files
committed
[NVPTX] Remove load/store type
1 parent 793bee4 commit 9aad491

File tree

189 files changed

+8554
-8582
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

189 files changed

+8554
-8582
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 13 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1044,21 +1044,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
10441044
}
10451045
}
10461046

1047-
static int getLdStRegType(EVT VT) {
1048-
if (VT.isFloatingPoint())
1049-
switch (VT.getSimpleVT().SimpleTy) {
1050-
case MVT::f16:
1051-
case MVT::bf16:
1052-
case MVT::v2f16:
1053-
case MVT::v2bf16:
1054-
return NVPTX::PTXLdStInstCode::Untyped;
1055-
default:
1056-
return NVPTX::PTXLdStInstCode::Float;
1057-
}
1058-
else
1059-
return NVPTX::PTXLdStInstCode::Unsigned;
1060-
}
1061-
10621047
bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10631048
MemSDNode *LD = cast<MemSDNode>(N);
10641049
assert(LD->readMem() && "Expected load");
@@ -1088,24 +1073,14 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10881073
// type is integer
10891074
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
10901075
MVT SimpleVT = LoadedVT.getSimpleVT();
1091-
MVT ScalarVT = SimpleVT.getScalarType();
10921076
// Read at least 8 bits (predicates are stored as 8-bit values)
1093-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1094-
unsigned int FromType;
1077+
unsigned FromTypeWidth = std::max(8U, (unsigned)SimpleVT.getSizeInBits());
10951078

10961079
// Vector Setting
1097-
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1098-
if (SimpleVT.isVector()) {
1099-
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
1100-
"Unexpected vector type");
1101-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1102-
FromTypeWidth = 32;
1103-
}
1104-
1105-
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1106-
FromType = NVPTX::PTXLdStInstCode::Signed;
1107-
else
1108-
FromType = getLdStRegType(ScalarVT);
1080+
unsigned int FromType =
1081+
(PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1082+
? NVPTX::PTXLdStInstCode::Signed
1083+
: NVPTX::PTXLdStInstCode::Untyped;
11091084

11101085
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
11111086
FromTypeWidth <= 128 && "Invalid width for load");
@@ -1116,7 +1091,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11161091
SDValue Ops[] = {getI32Imm(Ordering, DL),
11171092
getI32Imm(Scope, DL),
11181093
getI32Imm(CodeAddrSpace, DL),
1119-
getI32Imm(VecType, DL),
1094+
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
11201095
getI32Imm(FromType, DL),
11211096
getI32Imm(FromTypeWidth, DL),
11221097
Base,
@@ -1182,7 +1157,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11821157
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
11831158
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
11841159
? NVPTX::PTXLdStInstCode::Signed
1185-
: getLdStRegType(MemVT.getScalarType());
1160+
: NVPTX::PTXLdStInstCode::Untyped;
11861161

11871162
unsigned VecType;
11881163
unsigned FromTypeWidth;
@@ -1200,8 +1175,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12001175
}
12011176

12021177
if (isSubVectorPackedInI32(EltVT)) {
1178+
assert(ExtensionType == ISD::NON_EXTLOAD);
12031179
EltVT = MVT::i32;
1204-
FromType = NVPTX::PTXLdStInstCode::Untyped;
12051180
}
12061181

12071182
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
@@ -1405,21 +1380,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14051380
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
14061381

14071382
// Vector Setting
1408-
MVT SimpleVT = StoreVT.getSimpleVT();
1409-
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1410-
1411-
// Type Setting: toType + toTypeWidth
1412-
// - for integer type, always use 'u'
1413-
MVT ScalarVT = SimpleVT.getScalarType();
1414-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1415-
if (SimpleVT.isVector()) {
1416-
assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1417-
"Unexpected vector type");
1418-
// v2x16 is stored using st.b32
1419-
ToTypeWidth = 32;
1420-
}
1421-
1422-
unsigned int ToType = getLdStRegType(ScalarVT);
1383+
const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();
14231384

14241385
// Create the machine instruction DAG
14251386
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
@@ -1434,8 +1395,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14341395
getI32Imm(Ordering, DL),
14351396
getI32Imm(Scope, DL),
14361397
getI32Imm(CodeAddrSpace, DL),
1437-
getI32Imm(VecType, DL),
1438-
getI32Imm(ToType, DL),
1398+
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
1399+
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
14391400
getI32Imm(ToTypeWidth, DL),
14401401
Base,
14411402
Offset,
@@ -1481,7 +1442,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14811442
// Type Setting: toType + toTypeWidth
14821443
// - for integer type, always use 'u'
14831444
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
1484-
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());
14851445

14861446
SmallVector<SDValue, 12> Ops;
14871447
SDValue N2;
@@ -1508,7 +1468,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15081468

15091469
if (isSubVectorPackedInI32(EltVT)) {
15101470
EltVT = MVT::i32;
1511-
ToType = NVPTX::PTXLdStInstCode::Untyped;
15121471
}
15131472

15141473
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
@@ -1519,8 +1478,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15191478

15201479
Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
15211480
getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
1522-
getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL), Base, Offset,
1523-
Chain});
1481+
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
1482+
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});
15241483

15251484
std::optional<unsigned> Opcode;
15261485
switch (N->getOpcode()) {

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2249,11 +2249,11 @@ def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">;
22492249
def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">;
22502250
def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">;
22512251
def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">;
2252-
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">;
2253-
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">;
2254-
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
2255-
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
2256-
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">;
2252+
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".b32">;
2253+
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".b64">;
2254+
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".b32">;
2255+
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".b64">;
2256+
def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".b32">;
22572257

22582258
defm StoreParamI64 : StoreParamInst<Int64Regs, i64imm, ".b64">;
22592259
defm StoreParamI32 : StoreParamInst<Int32Regs, i32imm, ".b32">;
@@ -2272,13 +2272,13 @@ defm StoreParamV4I32 : StoreParamV4Inst<Int32Regs, i32imm, ".b32">;
22722272
defm StoreParamV4I16 : StoreParamV4Inst<Int16Regs, i16imm, ".b16">;
22732273
defm StoreParamV4I8 : StoreParamV4Inst<Int16Regs, i8imm, ".b8">;
22742274

2275-
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".f32">;
2276-
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".f64">;
2275+
defm StoreParamF32 : StoreParamInst<Float32Regs, f32imm, ".b32">;
2276+
defm StoreParamF64 : StoreParamInst<Float64Regs, f64imm, ".b64">;
22772277

2278-
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".f32">;
2279-
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".f64">;
2278+
defm StoreParamV2F32 : StoreParamV2Inst<Float32Regs, f32imm, ".b32">;
2279+
defm StoreParamV2F64 : StoreParamV2Inst<Float64Regs, f64imm, ".b64">;
22802280

2281-
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".f32">;
2281+
defm StoreParamV4F32 : StoreParamV4Inst<Float32Regs, f32imm, ".b32">;
22822282

22832283
def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">;
22842284
def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">;
@@ -2294,11 +2294,11 @@ def StoreRetvalV4I32 : StoreRetvalV4Inst<Int32Regs, ".b32">;
22942294
def StoreRetvalV4I16 : StoreRetvalV4Inst<Int16Regs, ".b16">;
22952295
def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">;
22962296

2297-
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">;
2298-
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">;
2299-
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">;
2300-
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">;
2301-
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">;
2297+
def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".b64">;
2298+
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".b32">;
2299+
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".b64">;
2300+
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".b32">;
2301+
def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".b32">;
23022302

23032303
def CallArgBeginInst : NVPTXInst<(outs), (ins), "(", [(CallArgBegin)]>;
23042304
def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 33 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -2329,12 +2329,12 @@ class LDU_G<string TyStr, NVPTXRegClass regclass>
23292329
"ldu.global." # TyStr # " \t$result, [$src];",
23302330
[]>, Requires<[hasLDU]>;
23312331

2332-
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8", Int16Regs>;
2333-
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16", Int16Regs>;
2334-
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32", Int32Regs>;
2335-
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64", Int64Regs>;
2336-
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32", Float32Regs>;
2337-
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64", Float64Regs>;
2332+
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"b8", Int16Regs>;
2333+
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"b16", Int16Regs>;
2334+
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"b32", Int32Regs>;
2335+
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"b64", Int64Regs>;
2336+
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"b32", Float32Regs>;
2337+
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"b64", Float64Regs>;
23382338

23392339
// vector
23402340

@@ -2351,19 +2351,19 @@ class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass>
23512351
"ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
23522352

23532353

2354-
def INT_PTX_LDU_G_v2i8_ELE : VLDU_G_ELE_V2<"u8", Int16Regs>;
2355-
def INT_PTX_LDU_G_v2i16_ELE : VLDU_G_ELE_V2<"u16", Int16Regs>;
2356-
def INT_PTX_LDU_G_v2i32_ELE : VLDU_G_ELE_V2<"u32", Int32Regs>;
2357-
def INT_PTX_LDU_G_v2f32_ELE : VLDU_G_ELE_V2<"f32", Float32Regs>;
2358-
def INT_PTX_LDU_G_v2i64_ELE : VLDU_G_ELE_V2<"u64", Int64Regs>;
2359-
def INT_PTX_LDU_G_v2f64_ELE : VLDU_G_ELE_V2<"f64", Float64Regs>;
2354+
def INT_PTX_LDU_G_v2i8_ELE : VLDU_G_ELE_V2<"b8", Int16Regs>;
2355+
def INT_PTX_LDU_G_v2i16_ELE : VLDU_G_ELE_V2<"b16", Int16Regs>;
2356+
def INT_PTX_LDU_G_v2i32_ELE : VLDU_G_ELE_V2<"b32", Int32Regs>;
2357+
def INT_PTX_LDU_G_v2f32_ELE : VLDU_G_ELE_V2<"b32", Float32Regs>;
2358+
def INT_PTX_LDU_G_v2i64_ELE : VLDU_G_ELE_V2<"b64", Int64Regs>;
2359+
def INT_PTX_LDU_G_v2f64_ELE : VLDU_G_ELE_V2<"b64", Float64Regs>;
23602360

2361-
def INT_PTX_LDU_G_v4i8_ELE : VLDU_G_ELE_V4<"u8", Int16Regs>;
2362-
def INT_PTX_LDU_G_v4i16_ELE : VLDU_G_ELE_V4<"u16", Int16Regs>;
2363-
def INT_PTX_LDU_G_v4i32_ELE : VLDU_G_ELE_V4<"u32", Int32Regs>;
2361+
def INT_PTX_LDU_G_v4i8_ELE : VLDU_G_ELE_V4<"b8", Int16Regs>;
2362+
def INT_PTX_LDU_G_v4i16_ELE : VLDU_G_ELE_V4<"b16", Int16Regs>;
2363+
def INT_PTX_LDU_G_v4i32_ELE : VLDU_G_ELE_V4<"b32", Int32Regs>;
23642364
def INT_PTX_LDU_G_v4f16_ELE : VLDU_G_ELE_V4<"b16", Int16Regs>;
23652365
def INT_PTX_LDU_G_v4f16x2_ELE : VLDU_G_ELE_V4<"b32", Int32Regs>;
2366-
def INT_PTX_LDU_G_v4f32_ELE : VLDU_G_ELE_V4<"f32", Float32Regs>;
2366+
def INT_PTX_LDU_G_v4f32_ELE : VLDU_G_ELE_V4<"b32", Float32Regs>;
23672367

23682368

23692369
//-----------------------------------
@@ -2379,12 +2379,12 @@ class LDG_G<string TyStr, NVPTXRegClass regclass>
23792379
"ld.global.nc." # TyStr # " \t$result, [$src];",
23802380
[]>, Requires<[hasLDG]>;
23812381

2382-
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"u8", Int16Regs>;
2383-
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"u16", Int16Regs>;
2384-
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"u32", Int32Regs>;
2385-
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"u64", Int64Regs>;
2386-
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32", Float32Regs>;
2387-
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"f64", Float64Regs>;
2382+
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"b8", Int16Regs>;
2383+
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"b16", Int16Regs>;
2384+
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"b32", Int32Regs>;
2385+
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"b64", Int64Regs>;
2386+
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"b32", Float32Regs>;
2387+
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"b64", Float64Regs>;
23882388

23892389
// vector
23902390

@@ -2401,17 +2401,17 @@ class VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> :
24012401
"ld.global.nc.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
24022402

24032403
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
2404-
def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"u8", Int16Regs>;
2405-
def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"u16", Int16Regs>;
2406-
def INT_PTX_LDG_G_v2i32_ELE : VLDG_G_ELE_V2<"u32", Int32Regs>;
2407-
def INT_PTX_LDG_G_v2f32_ELE : VLDG_G_ELE_V2<"f32", Float32Regs>;
2408-
def INT_PTX_LDG_G_v2i64_ELE : VLDG_G_ELE_V2<"u64", Int64Regs>;
2409-
def INT_PTX_LDG_G_v2f64_ELE : VLDG_G_ELE_V2<"f64", Float64Regs>;
2410-
2411-
def INT_PTX_LDG_G_v4i8_ELE : VLDG_G_ELE_V4<"u8", Int16Regs>;
2412-
def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"u16", Int16Regs>;
2413-
def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"u32", Int32Regs>;
2414-
def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"f32", Float32Regs>;
2404+
def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"b8", Int16Regs>;
2405+
def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"b16", Int16Regs>;
2406+
def INT_PTX_LDG_G_v2i32_ELE : VLDG_G_ELE_V2<"b32", Int32Regs>;
2407+
def INT_PTX_LDG_G_v2f32_ELE : VLDG_G_ELE_V2<"b32", Float32Regs>;
2408+
def INT_PTX_LDG_G_v2i64_ELE : VLDG_G_ELE_V2<"b64", Int64Regs>;
2409+
def INT_PTX_LDG_G_v2f64_ELE : VLDG_G_ELE_V2<"b64", Float64Regs>;
2410+
2411+
def INT_PTX_LDG_G_v4i8_ELE : VLDG_G_ELE_V4<"b8", Int16Regs>;
2412+
def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"b16", Int16Regs>;
2413+
def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"b32", Int32Regs>;
2414+
def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"b32", Float32Regs>;
24152415

24162416

24172417
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {

0 commit comments

Comments
 (0)