@@ -1173,7 +1173,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11731173 return true ;
11741174}
11751175
1176- static bool isVectorElementTypeUpsized (EVT EltVT) {
1176+ static bool isSubVectorPackedInI32 (EVT EltVT) {
11771177 // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
11781178 // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
11791179 // vectorized loads/stores with the actual element type for i8/i16 as that
@@ -1186,60 +1186,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {
11861186
11871187bool NVPTXDAGToDAGISel::tryLoadVector (SDNode *N) {
11881188 MemSDNode *MemSD = cast<MemSDNode>(N);
1189- EVT LoadedVT = MemSD->getMemoryVT ();
1190- if (!LoadedVT .isSimple ())
1189+ const EVT MemEVT = MemSD->getMemoryVT ();
1190+ if (!MemEVT .isSimple ())
11911191 return false ;
1192+ const MVT MemVT = MemEVT.getSimpleVT ();
11921193
11931194 // Address Space Setting
11941195 unsigned int CodeAddrSpace = getCodeAddrSpace (MemSD);
11951196 if (canLowerToLDG (MemSD, *Subtarget, CodeAddrSpace, MF)) {
11961197 return tryLDGLDU (N);
11971198 }
11981199
1200+ EVT EltVT = N->getValueType (0 );
11991201 SDLoc DL (N);
12001202 SDValue Chain = N->getOperand (0 );
12011203 auto [Ordering, Scope] = insertMemoryInstructionFence (DL, Chain, MemSD);
12021204
1203- // Vector Setting
1204- MVT SimpleVT = LoadedVT.getSimpleVT ();
1205-
12061205 // Type Setting: fromType + fromTypeWidth
12071206 //
12081207 // Sign : ISD::SEXTLOAD
12091208 // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
12101209 // type is integer
12111210 // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1212- MVT ScalarVT = SimpleVT.getScalarType ();
12131211 // Read at least 8 bits (predicates are stored as 8-bit values)
1214- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1215- unsigned int FromType;
12161212 // The last operand holds the original LoadSDNode::getExtensionType() value
1217- unsigned ExtensionType = cast<ConstantSDNode>(
1218- N->getOperand (N->getNumOperands () - 1 ))->getZExtValue ();
1219- if (ExtensionType == ISD::SEXTLOAD)
1220- FromType = NVPTX::PTXLdStInstCode::Signed;
1221- else
1222- FromType = getLdStRegType (ScalarVT);
1213+ const unsigned TotalWidth = MemVT.getSizeInBits ();
1214+ unsigned ExtensionType = N->getConstantOperandVal (N->getNumOperands () - 1 );
1215+ unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
1216+ ? NVPTX::PTXLdStInstCode::Signed
1217+ : getLdStRegType (MemVT.getScalarType ());
12231218
12241219 unsigned VecType;
1225-
1220+ unsigned FromTypeWidth;
12261221 switch (N->getOpcode ()) {
12271222 case NVPTXISD::LoadV2:
1223+ FromTypeWidth = TotalWidth / 2 ;
12281224 VecType = NVPTX::PTXLdStInstCode::V2;
12291225 break ;
12301226 case NVPTXISD::LoadV4:
1227+ FromTypeWidth = TotalWidth / 4 ;
12311228 VecType = NVPTX::PTXLdStInstCode::V4;
12321229 break ;
12331230 default :
12341231 return false ;
12351232 }
12361233
1237- EVT EltVT = N->getValueType (0 );
1238-
1239- if (isVectorElementTypeUpsized (EltVT)) {
1234+ if (isSubVectorPackedInI32 (EltVT)) {
12401235 EltVT = MVT::i32 ;
12411236 FromType = NVPTX::PTXLdStInstCode::Untyped;
1242- FromTypeWidth = 32 ;
12431237 }
12441238
12451239 SDValue Offset, Base;
@@ -1289,9 +1283,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12891283 // LDG/LDU SD node (from custom vector handling), then its the second operand
12901284 SDValue Op1 = N->getOperand (N->getOpcode () == ISD::INTRINSIC_W_CHAIN ? 2 : 1 );
12911285
1292- EVT OrigType = N->getValueType (0 );
1286+ const EVT OrigType = N->getValueType (0 );
12931287 EVT EltVT = Mem->getMemoryVT ();
12941288 unsigned NumElts = 1 ;
1289+
1290+ if (EltVT == MVT::i128 || EltVT == MVT::f128 ) {
1291+ EltVT = MVT::i64 ;
1292+ NumElts = 2 ;
1293+ }
12951294 if (EltVT.isVector ()) {
12961295 NumElts = EltVT.getVectorNumElements ();
12971296 EltVT = EltVT.getVectorElementType ();
@@ -1311,11 +1310,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13111310 // Build the "promoted" result VTList for the load. If we are really loading
13121311 // i8s, then the return type will be promoted to i16 since we do not expose
13131312 // 8-bit registers in NVPTX.
1314- EVT NodeVT = (EltVT == MVT::i8 ) ? MVT::i16 : EltVT;
1313+ const EVT NodeVT = (EltVT == MVT::i8 ) ? MVT::i16 : EltVT;
13151314 SmallVector<EVT, 5 > InstVTs;
1316- for (unsigned i = 0 ; i != NumElts; ++i) {
1317- InstVTs.push_back (NodeVT);
1318- }
1315+ InstVTs.append (NumElts, NodeVT);
13191316 InstVTs.push_back (MVT::Other);
13201317 SDVTList InstVTList = CurDAG->getVTList (InstVTs);
13211318 SDValue Chain = N->getOperand (0 );
@@ -1494,6 +1491,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14941491 EVT EltVT = Op1.getValueType ();
14951492 MemSDNode *MemSD = cast<MemSDNode>(N);
14961493 EVT StoreVT = MemSD->getMemoryVT ();
1494+ assert (StoreVT.isSimple () && " Store value is not simple" );
14971495
14981496 // Address Space Setting
14991497 unsigned CodeAddrSpace = getCodeAddrSpace (MemSD);
@@ -1508,35 +1506,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15081506
15091507 // Type Setting: toType + toTypeWidth
15101508 // - for integer type, always use 'u'
1511- assert (StoreVT.isSimple () && " Store value is not simple" );
1512- MVT ScalarVT = StoreVT.getSimpleVT ().getScalarType ();
1513- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1514- unsigned ToType = getLdStRegType (ScalarVT);
1509+ const unsigned TotalWidth = StoreVT.getSimpleVT ().getSizeInBits ();
1510+ unsigned ToType = getLdStRegType (StoreVT.getSimpleVT ().getScalarType ());
15151511
15161512 SmallVector<SDValue, 12 > Ops;
15171513 SDValue N2;
15181514 unsigned VecType;
1515+ unsigned ToTypeWidth;
15191516
15201517 switch (N->getOpcode ()) {
15211518 case NVPTXISD::StoreV2:
15221519 VecType = NVPTX::PTXLdStInstCode::V2;
15231520 Ops.append ({N->getOperand (1 ), N->getOperand (2 )});
15241521 N2 = N->getOperand (3 );
1522+ ToTypeWidth = TotalWidth / 2 ;
15251523 break ;
15261524 case NVPTXISD::StoreV4:
15271525 VecType = NVPTX::PTXLdStInstCode::V4;
15281526 Ops.append ({N->getOperand (1 ), N->getOperand (2 ), N->getOperand (3 ),
15291527 N->getOperand (4 )});
15301528 N2 = N->getOperand (5 );
1529+ ToTypeWidth = TotalWidth / 4 ;
15311530 break ;
15321531 default :
15331532 return false ;
15341533 }
15351534
1536- if (isVectorElementTypeUpsized (EltVT)) {
1535+ if (isSubVectorPackedInI32 (EltVT)) {
15371536 EltVT = MVT::i32 ;
15381537 ToType = NVPTX::PTXLdStInstCode::Untyped;
1539- ToTypeWidth = 32 ;
15401538 }
15411539
15421540 SDValue Offset, Base;
0 commit comments