@@ -162,6 +162,14 @@ static bool IsPTXVectorType(MVT VT) {
162162 case MVT::v2f32:
163163 case MVT::v4f32:
164164 case MVT::v2f64:
165+ case MVT::v4i64:
166+ case MVT::v4f64:
167+ case MVT::v8i32:
168+ case MVT::v8f32:
169+ case MVT::v16f16: // <8 x f16x2>
170+ case MVT::v16bf16: // <8 x bf16x2>
171+ case MVT::v16i16: // <8 x i16x2>
172+ case MVT::v32i8: // <8 x i8x4>
165173 return true ;
166174 }
167175}
@@ -179,7 +187,7 @@ static bool Is16bitsType(MVT VT) {
179187// - unsigned int NumElts - The number of elements in the final vector
180188// - EVT EltVT - The type of the elements in the final vector
181189static std::optional<std::pair<unsigned int , MVT>>
182- getVectorLoweringShape (EVT VectorEVT) {
190+ getVectorLoweringShape (EVT VectorEVT, bool CanLowerTo256Bit ) {
183191 if (!VectorEVT.isSimple ())
184192 return std::nullopt ;
185193 const MVT VectorVT = VectorEVT.getSimpleVT ();
@@ -199,6 +207,15 @@ getVectorLoweringShape(EVT VectorEVT) {
199207 switch (VectorVT.SimpleTy ) {
200208 default :
201209 return std::nullopt ;
210+ case MVT::v4i64:
211+ case MVT::v4f64:
212+ case MVT::v8i32:
213+ case MVT::v8f32:
214+ // This is a "native" vector type iff the address space is global
215+ // and the target supports 256-bit loads/stores
216+ if (!CanLowerTo256Bit)
217+ return std::nullopt ;
218+ LLVM_FALLTHROUGH;
202219 case MVT::v2i8:
203220 case MVT::v2i16:
204221 case MVT::v2i32:
@@ -215,6 +232,15 @@ getVectorLoweringShape(EVT VectorEVT) {
215232 case MVT::v4f32:
216233 // This is a "native" vector type
217234 return std::pair (NumElts, EltVT);
235+ case MVT::v16f16: // <8 x f16x2>
236+ case MVT::v16bf16: // <8 x bf16x2>
237+ case MVT::v16i16: // <8 x i16x2>
238+ case MVT::v32i8: // <8 x i8x4>
239+ // This can be upsized into a "native" vector type iff the address space is
240+ // global and the target supports 256-bit loads/stores.
241+ if (!CanLowerTo256Bit)
242+ return std::nullopt ;
243+ LLVM_FALLTHROUGH;
218244 case MVT::v8i8: // <2 x i8x4>
219245 case MVT::v8f16: // <4 x f16x2>
220246 case MVT::v8bf16: // <4 x bf16x2>
@@ -1070,10 +1096,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10701096 MAKE_CASE (NVPTXISD::ProxyReg)
10711097 MAKE_CASE (NVPTXISD::LoadV2)
10721098 MAKE_CASE (NVPTXISD::LoadV4)
1099+ MAKE_CASE (NVPTXISD::LoadV8)
10731100 MAKE_CASE (NVPTXISD::LDUV2)
10741101 MAKE_CASE (NVPTXISD::LDUV4)
10751102 MAKE_CASE (NVPTXISD::StoreV2)
10761103 MAKE_CASE (NVPTXISD::StoreV4)
1104+ MAKE_CASE (NVPTXISD::StoreV8)
10771105 MAKE_CASE (NVPTXISD::FSHL_CLAMP)
10781106 MAKE_CASE (NVPTXISD::FSHR_CLAMP)
10791107 MAKE_CASE (NVPTXISD::BFE)
@@ -3201,7 +3229,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32013229 if (ValVT != MemVT)
32023230 return SDValue ();
32033231
3204- const auto NumEltsAndEltVT = getVectorLoweringShape (ValVT);
3232+ // 256-bit vectors are only allowed iff the address is global
3233+ // and the target supports 256-bit loads/stores
3234+ unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace ();
3235+ bool CanLowerTo256Bit =
3236+ AddrSpace == ADDRESS_SPACE_GLOBAL && STI.has256BitMaskedLoadStore ();
3237+ const auto NumEltsAndEltVT = getVectorLoweringShape (ValVT, CanLowerTo256Bit);
32053238 if (!NumEltsAndEltVT)
32063239 return SDValue ();
32073240 const auto [NumElts, EltVT] = NumEltsAndEltVT.value ();
@@ -3229,6 +3262,9 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32293262 case 4 :
32303263 Opcode = NVPTXISD::StoreV4;
32313264 break ;
3265+ case 8 :
3266+ Opcode = NVPTXISD::StoreV8;
3267+ break ;
32323268 }
32333269
32343270 SmallVector<SDValue, 8 > Ops;
@@ -5765,7 +5801,8 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
57655801
57665802/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
57675803static void ReplaceLoadVector (SDNode *N, SelectionDAG &DAG,
5768- SmallVectorImpl<SDValue> &Results) {
5804+ SmallVectorImpl<SDValue> &Results,
5805+ bool TargetHas256BitVectorLoadStore) {
57695806 LoadSDNode *LD = cast<LoadSDNode>(N);
57705807 const EVT ResVT = LD->getValueType (0 );
57715808 const EVT MemVT = LD->getMemoryVT ();
@@ -5775,7 +5812,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
57755812 if (ResVT != MemVT)
57765813 return ;
57775814
5778- const auto NumEltsAndEltVT = getVectorLoweringShape (ResVT);
5815+ // 256-bit vectors are only allowed iff the address is global
5816+ // and the target supports 256-bit loads/stores
5817+ unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace ();
5818+ bool CanLowerTo256Bit =
5819+ AddrSpace == ADDRESS_SPACE_GLOBAL && TargetHas256BitVectorLoadStore;
5820+ const auto NumEltsAndEltVT = getVectorLoweringShape (ResVT, CanLowerTo256Bit);
57795821 if (!NumEltsAndEltVT)
57805822 return ;
57815823 const auto [NumElts, EltVT] = NumEltsAndEltVT.value ();
@@ -5812,6 +5854,13 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
58125854 DAG.getVTList ({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
58135855 break ;
58145856 }
5857+ case 8 : {
5858+ Opcode = NVPTXISD::LoadV8;
5859+ EVT ListVTs[] = {LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT,
5860+ LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other};
5861+ LdResVTs = DAG.getVTList (ListVTs);
5862+ break ;
5863+ }
58155864 }
58165865 SDLoc DL (LD);
58175866
@@ -6084,7 +6133,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
60846133 ReplaceBITCAST (N, DAG, Results);
60856134 return ;
60866135 case ISD::LOAD:
6087- ReplaceLoadVector (N, DAG, Results);
6136+ ReplaceLoadVector (N, DAG, Results, STI. has256BitMaskedLoadStore () );
60886137 return ;
60896138 case ISD::INTRINSIC_W_CHAIN:
60906139 ReplaceINTRINSIC_W_CHAIN (N, DAG, Results);
0 commit comments