@@ -1168,60 +1168,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {
11681168
11691169bool NVPTXDAGToDAGISel::tryLoadVector (SDNode *N) {
11701170 MemSDNode *MemSD = cast<MemSDNode>(N);
1171- EVT LoadedVT = MemSD->getMemoryVT ();
1172- if (!LoadedVT .isSimple ())
1171+ EVT MemEVT = MemSD->getMemoryVT ();
1172+ if (!MemEVT .isSimple ())
11731173 return false ;
1174+ MVT MemVT = MemEVT.getSimpleVT ();
11741175
11751176 // Address Space Setting
11761177 unsigned int CodeAddrSpace = getCodeAddrSpace (MemSD);
11771178 if (canLowerToLDG (MemSD, *Subtarget, CodeAddrSpace, MF)) {
11781179 return tryLDGLDU (N);
11791180 }
11801181
1182+ EVT EltVT = N->getValueType (0 );
11811183 SDLoc DL (N);
11821184 SDValue Chain = N->getOperand (0 );
11831185 auto [Ordering, Scope] = insertMemoryInstructionFence (DL, Chain, MemSD);
11841186
1185- // Vector Setting
1186- MVT SimpleVT = LoadedVT.getSimpleVT ();
1187-
11881187 // Type Setting: fromType + fromTypeWidth
11891188 //
11901189 // Sign : ISD::SEXTLOAD
11911190 // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
11921191 // type is integer
11931192 // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1194- MVT ScalarVT = SimpleVT.getScalarType ();
11951193 // Read at least 8 bits (predicates are stored as 8-bit values)
1196- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1197- unsigned int FromType;
11981194 // The last operand holds the original LoadSDNode::getExtensionType() value
1199- unsigned ExtensionType = cast<ConstantSDNode>(
1200- N->getOperand (N->getNumOperands () - 1 ))->getZExtValue ();
1201- if (ExtensionType == ISD::SEXTLOAD)
1202- FromType = NVPTX::PTXLdStInstCode::Signed;
1203- else
1204- FromType = getLdStRegType (ScalarVT);
1195+ const unsigned TotalWidth = MemVT.getSizeInBits ();
1196+ unsigned ExtensionType = N->getConstantOperandVal (N->getNumOperands () - 1 );
1197+ unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
1198+ ? NVPTX::PTXLdStInstCode::Signed
1199+ : getLdStRegType (MemVT.getScalarType ());
12051200
12061201 unsigned VecType;
1207-
1202+ unsigned FromTypeWidth;
12081203 switch (N->getOpcode ()) {
12091204 case NVPTXISD::LoadV2:
1205+ FromTypeWidth = TotalWidth / 2 ;
12101206 VecType = NVPTX::PTXLdStInstCode::V2;
12111207 break ;
12121208 case NVPTXISD::LoadV4:
1209+ FromTypeWidth = TotalWidth / 4 ;
12131210 VecType = NVPTX::PTXLdStInstCode::V4;
12141211 break ;
12151212 default :
12161213 return false ;
12171214 }
12181215
1219- EVT EltVT = N->getValueType (0 );
1220-
12211216 if (isVectorElementTypeUpsized (EltVT)) {
12221217 EltVT = MVT::i32 ;
12231218 FromType = NVPTX::PTXLdStInstCode::Untyped;
1224- FromTypeWidth = 32 ;
12251219 }
12261220
12271221 SDValue Offset, Base;
@@ -1271,9 +1265,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12711265 // LDG/LDU SD node (from custom vector handling), then it's the second operand
12721266 SDValue Op1 = N->getOperand (N->getOpcode () == ISD::INTRINSIC_W_CHAIN ? 2 : 1 );
12731267
1274- EVT OrigType = N->getValueType (0 );
1268+ const EVT OrigType = N->getValueType (0 );
12751269 EVT EltVT = Mem->getMemoryVT ();
12761270 unsigned NumElts = 1 ;
1271+
1272+ if (EltVT == MVT::i128 || EltVT == MVT::f128 ) {
1273+ EltVT = MVT::i64 ;
1274+ NumElts = 2 ;
1275+ }
12771276 if (EltVT.isVector ()) {
12781277 NumElts = EltVT.getVectorNumElements ();
12791278 EltVT = EltVT.getVectorElementType ();
@@ -1293,11 +1292,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12931292 // Build the "promoted" result VTList for the load. If we are really loading
12941293 // i8s, then the return type will be promoted to i16 since we do not expose
12951294 // 8-bit registers in NVPTX.
1296- EVT NodeVT = (EltVT == MVT::i8 ) ? MVT::i16 : EltVT;
1295+ const EVT NodeVT = (EltVT == MVT::i8 ) ? MVT::i16 : EltVT;
12971296 SmallVector<EVT, 5 > InstVTs;
1298- for (unsigned i = 0 ; i != NumElts; ++i) {
1299- InstVTs.push_back (NodeVT);
1300- }
1297+ InstVTs.append (NumElts, NodeVT);
13011298 InstVTs.push_back (MVT::Other);
13021299 SDVTList InstVTList = CurDAG->getVTList (InstVTs);
13031300 SDValue Chain = N->getOperand (0 );
@@ -1476,6 +1473,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14761473 EVT EltVT = Op1.getValueType ();
14771474 MemSDNode *MemSD = cast<MemSDNode>(N);
14781475 EVT StoreVT = MemSD->getMemoryVT ();
1476+ assert (StoreVT.isSimple () && " Store value is not simple" );
14791477
14801478 // Address Space Setting
14811479 unsigned CodeAddrSpace = getCodeAddrSpace (MemSD);
@@ -1490,35 +1488,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14901488
14911489 // Type Setting: toType + toTypeWidth
14921490 // - for integer type, always use 'u'
1493- assert (StoreVT.isSimple () && " Store value is not simple" );
1494- MVT ScalarVT = StoreVT.getSimpleVT ().getScalarType ();
1495- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1496- unsigned ToType = getLdStRegType (ScalarVT);
1491+ const unsigned TotalWidth = StoreVT.getSimpleVT ().getSizeInBits ();
1492+ unsigned ToType = getLdStRegType (StoreVT.getSimpleVT ().getScalarType ());
14971493
14981494 SmallVector<SDValue, 12 > Ops;
14991495 SDValue N2;
15001496 unsigned VecType;
1497+ unsigned ToTypeWidth;
15011498
15021499 switch (N->getOpcode ()) {
15031500 case NVPTXISD::StoreV2:
15041501 VecType = NVPTX::PTXLdStInstCode::V2;
15051502 Ops.append ({N->getOperand (1 ), N->getOperand (2 )});
15061503 N2 = N->getOperand (3 );
1504+ ToTypeWidth = TotalWidth / 2 ;
15071505 break ;
15081506 case NVPTXISD::StoreV4:
15091507 VecType = NVPTX::PTXLdStInstCode::V4;
15101508 Ops.append ({N->getOperand (1 ), N->getOperand (2 ), N->getOperand (3 ),
15111509 N->getOperand (4 )});
15121510 N2 = N->getOperand (5 );
1511+ ToTypeWidth = TotalWidth / 4 ;
15131512 break ;
15141513 default :
15151514 return false ;
15161515 }
15171516
15181517 if (isVectorElementTypeUpsized (EltVT)) {
1519- EltVT = MVT::i32 ;
15201518 ToType = NVPTX::PTXLdStInstCode::Untyped;
1521- ToTypeWidth = 32 ;
1519+ EltVT = MVT:: i32 ;
15221520 }
15231521
15241522 SDValue Offset, Base;
0 commit comments