@@ -1060,6 +1060,7 @@ static int getLdStRegType(EVT VT) {
10601060 case MVT::bf16 :
10611061 case MVT::v2f16:
10621062 case MVT::v2bf16:
1063+ case MVT::v2f32:
10631064 return NVPTX::PTXLdStInstCode::Untyped;
10641065 default :
10651066 return NVPTX::PTXLdStInstCode::Float;
@@ -1099,24 +1100,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10991100 // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
11001101 MVT SimpleVT = LoadedVT.getSimpleVT ();
11011102 MVT ScalarVT = SimpleVT.getScalarType ();
1102- // Read at least 8 bits (predicates are stored as 8-bit values)
1103- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1104- unsigned int FromType;
11051103
11061104 // Vector Setting
11071105 unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11081106 if (SimpleVT.isVector ()) {
1109- if (Isv2x16VT (LoadedVT) || LoadedVT == MVT::v4i8)
1110- // v2f16/v2bf16/v2i16 is loaded using ld.b32
1111- FromTypeWidth = 32 ;
1112- else if (LoadedVT == MVT::v2f32)
1113- // v2f32 is loaded using ld.b64
1114- FromTypeWidth = 64 ;
1115- else
1116- llvm_unreachable (" Unexpected vector type" );
1107+ switch (LoadedVT.getSimpleVT ().SimpleTy ) {
1108+ case MVT::v2f16:
1109+ case MVT::v2bf16:
1110+ case MVT::v2i16:
1111+ case MVT::v4i8:
1112+ case MVT::v2f32:
1113+ ScalarVT = LoadedVT.getSimpleVT ();
1114+ break ;
1115+ default :
1116+ llvm_unreachable (" Unsupported vector type for non-vector load" );
1117+ }
11171118 }
11181119
1119- if (PlainLoad && (PlainLoad->getExtensionType () == ISD::SEXTLOAD))
1120+ // Read at least 8 bits (predicates are stored as 8-bit values)
1121+ unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1122+ unsigned int FromType;
1123+ if (PlainLoad && PlainLoad->getExtensionType () == ISD::SEXTLOAD)
11201124 FromType = NVPTX::PTXLdStInstCode::Signed;
11211125 else
11221126 FromType = getLdStRegType (ScalarVT);
@@ -1424,18 +1428,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14241428 // Type Setting: toType + toTypeWidth
14251429 // - for integer type, always use 'u'
14261430 MVT ScalarVT = SimpleVT.getScalarType ();
1427- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
14281431 if (SimpleVT.isVector ()) {
1429- if (Isv2x16VT (StoreVT) || StoreVT == MVT::v4i8)
1430- // v2x16 is stored using st.b32
1431- ToTypeWidth = 32 ;
1432- else if (StoreVT == MVT::v2f32)
1433- // v2f32 is stored using st.b64
1434- ToTypeWidth = 64 ;
1435- else
1436- llvm_unreachable (" Unexpected vector type" );
1432+ switch (StoreVT.getSimpleVT ().SimpleTy ) {
1433+ case MVT::v2f16:
1434+ case MVT::v2bf16:
1435+ case MVT::v2i16:
1436+ case MVT::v4i8:
1437+ case MVT::v2f32:
1438+ ScalarVT = StoreVT.getSimpleVT ();
1439+ break ;
1440+ default :
1441+ llvm_unreachable (" Unsupported vector type for non-vector store" );
1442+ }
14371443 }
14381444
1445+ unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
14391446 unsigned int ToType = getLdStRegType (ScalarVT);
14401447
14411448 // Create the machine instruction DAG
0 commit comments