@@ -1074,6 +1074,7 @@ static int getLdStRegType(EVT VT) {
10741074 case MVT::bf16 :
10751075 case MVT::v2f16:
10761076 case MVT::v2bf16:
1077+ case MVT::v2f32:
10771078 return NVPTX::PTXLdStInstCode::Untyped;
10781079 default :
10791080 return NVPTX::PTXLdStInstCode::Float;
@@ -1113,24 +1114,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11131114 // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
11141115 MVT SimpleVT = LoadedVT.getSimpleVT ();
11151116 MVT ScalarVT = SimpleVT.getScalarType ();
1116- // Read at least 8 bits (predicates are stored as 8-bit values)
1117- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1118- unsigned int FromType;
11191117
11201118 // Vector Setting
11211119 unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11221120 if (SimpleVT.isVector ()) {
1123- if (Isv2x16VT (LoadedVT) || LoadedVT == MVT::v4i8)
1124- // v2f16/v2bf16/v2i16 is loaded using ld.b32
1125- FromTypeWidth = 32 ;
1126- else if (LoadedVT == MVT::v2f32)
1127- // v2f32 is loaded using ld.b64
1128- FromTypeWidth = 64 ;
1129- else
1130- llvm_unreachable (" Unexpected vector type" );
1121+ switch (LoadedVT.getSimpleVT ().SimpleTy ) {
1122+ case MVT::v2f16:
1123+ case MVT::v2bf16:
1124+ case MVT::v2i16:
1125+ case MVT::v4i8:
1126+ case MVT::v2f32:
1127+ ScalarVT = LoadedVT.getSimpleVT ();
1128+ break ;
1129+ default :
1130+ llvm_unreachable (" Unsupported vector type for non-vector load" );
1131+ }
11311132 }
11321133
1133- if (PlainLoad && (PlainLoad->getExtensionType () == ISD::SEXTLOAD))
1134+ // Read at least 8 bits (predicates are stored as 8-bit values)
1135+ unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1136+ unsigned int FromType;
1137+ if (PlainLoad && PlainLoad->getExtensionType () == ISD::SEXTLOAD)
11341138 FromType = NVPTX::PTXLdStInstCode::Signed;
11351139 else
11361140 FromType = getLdStRegType (ScalarVT);
@@ -1438,18 +1442,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14381442 // Type Setting: toType + toTypeWidth
14391443 // - for integer type, always use 'u'
14401444 MVT ScalarVT = SimpleVT.getScalarType ();
1441- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
14421445 if (SimpleVT.isVector ()) {
1443- if (Isv2x16VT (StoreVT) || StoreVT == MVT::v4i8)
1444- // v2x16 is stored using st.b32
1445- ToTypeWidth = 32 ;
1446- else if (StoreVT == MVT::v2f32)
1447- // v2f32 is stored using st.b64
1448- ToTypeWidth = 64 ;
1449- else
1450- llvm_unreachable (" Unexpected vector type" );
1446+ switch (StoreVT.getSimpleVT ().SimpleTy ) {
1447+ case MVT::v2f16:
1448+ case MVT::v2bf16:
1449+ case MVT::v2i16:
1450+ case MVT::v4i8:
1451+ case MVT::v2f32:
1452+ ScalarVT = StoreVT.getSimpleVT ();
1453+ break ;
1454+ default :
1455+ llvm_unreachable (" Unsupported vector type for non-vector store" );
1456+ }
14511457 }
14521458
1459+ unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
14531460 unsigned int ToType = getLdStRegType (ScalarVT);
14541461
14551462 // Create the machine instruction DAG
0 commit comments