@@ -1278,6 +1278,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12781278 EltVT = MVT::i64 ;
12791279 NumElts = 2 ;
12801280 }
1281+
1282+ std::optional<unsigned > Opcode;
1283+
12811284 if (EltVT.isVector ()) {
12821285 NumElts = EltVT.getVectorNumElements ();
12831286 EltVT = EltVT.getVectorElementType ();
@@ -1290,6 +1293,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12901293 (EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
12911294 assert (NumElts % OrigType.getVectorNumElements () == 0 &&
12921295 " NumElts must be divisible by the number of elts in subvectors" );
1296+ if (N->getOpcode () == ISD::LOAD ||
1297+ N->getOpcode () == ISD::INTRINSIC_W_CHAIN) {
1298+ switch (OrigType.getSimpleVT ().SimpleTy ) {
1299+ case MVT::v2f32:
1300+ Opcode = N->getOpcode () == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b64
1301+ : NVPTX::INT_PTX_LDU_GLOBAL_b64;
1302+ break ;
1303+ case MVT::v2f16:
1304+ case MVT::v2bf16:
1305+ case MVT::v2i16:
1306+ case MVT::v4i8:
1307+ Opcode = N->getOpcode () == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b32
1308+ : NVPTX::INT_PTX_LDU_GLOBAL_b32;
1309+ break ;
1310+ default :
1311+ llvm_unreachable (" Unhandled packed vector type" );
1312+ }
1313+ }
12931314 EltVT = OrigType;
12941315 NumElts /= OrigType.getVectorNumElements ();
12951316 }
@@ -1309,50 +1330,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13091330 SelectADDR (Op1, Base, Offset);
13101331 SDValue Ops[] = {Base, Offset, Chain};
13111332
1312- std::optional<unsigned > Opcode;
1313- switch (N->getOpcode ()) {
1314- default :
1315- return false ;
1316- case ISD::LOAD:
1317- Opcode = pickOpcodeForVT (
1318- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_GLOBAL_i8,
1319- NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1320- NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1321- NVPTX::INT_PTX_LDG_GLOBAL_f64);
1322- break ;
1323- case ISD::INTRINSIC_W_CHAIN:
1324- Opcode = pickOpcodeForVT (
1325- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_GLOBAL_i8,
1326- NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1327- NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1328- NVPTX::INT_PTX_LDU_GLOBAL_f64);
1329- break ;
1330- case NVPTXISD::LoadV2:
1331- Opcode = pickOpcodeForVT (
1332- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1333- NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1334- NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1335- NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1336- break ;
1337- case NVPTXISD::LDUV2:
1338- Opcode = pickOpcodeForVT (
1339- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1340- NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1341- NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1342- NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1343- break ;
1344- case NVPTXISD::LoadV4:
1345- Opcode = pickOpcodeForVT (
1346- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1347- NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1348- std::nullopt , NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt );
1349- break ;
1350- case NVPTXISD::LDUV4:
1351- Opcode = pickOpcodeForVT (
1352- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1353- NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1354- std::nullopt , NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt );
1355- break ;
1333+ if (!Opcode) {
1334+ switch (N->getOpcode ()) {
1335+ default :
1336+ return false ;
1337+ case ISD::LOAD:
1338+ Opcode = pickOpcodeForVT (
1339+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_GLOBAL_i8,
1340+ NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1341+ NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1342+ NVPTX::INT_PTX_LDG_GLOBAL_f64);
1343+ break ;
1344+ case ISD::INTRINSIC_W_CHAIN:
1345+ Opcode = pickOpcodeForVT (
1346+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_GLOBAL_i8,
1347+ NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1348+ NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1349+ NVPTX::INT_PTX_LDU_GLOBAL_f64);
1350+ break ;
1351+ case NVPTXISD::LoadV2:
1352+ Opcode = pickOpcodeForVT (
1353+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1354+ NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1355+ NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1356+ NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1357+ break ;
1358+ case NVPTXISD::LDUV2:
1359+ Opcode = pickOpcodeForVT (
1360+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1361+ NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1362+ NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1363+ NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1364+ break ;
1365+ case NVPTXISD::LoadV4:
1366+ Opcode = pickOpcodeForVT (
1367+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1368+ NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1369+ std::nullopt , NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt );
1370+ break ;
1371+ case NVPTXISD::LDUV4:
1372+ Opcode = pickOpcodeForVT (
1373+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1374+ NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1375+ std::nullopt , NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt );
1376+ break ;
1377+ }
13561378 }
13571379 if (!Opcode)
13581380 return false ;
0 commit comments