@@ -1059,6 +1059,7 @@ static int getLdStRegType(EVT VT) {
1059
1059
case MVT::bf16 :
1060
1060
case MVT::v2f16:
1061
1061
case MVT::v2bf16:
1062
+ case MVT::v2f32:
1062
1063
return NVPTX::PTXLdStInstCode::Untyped;
1063
1064
default :
1064
1065
return NVPTX::PTXLdStInstCode::Float;
@@ -1097,24 +1098,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
1097
1098
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1098
1099
MVT SimpleVT = LoadedVT.getSimpleVT ();
1099
1100
MVT ScalarVT = SimpleVT.getScalarType ();
1100
- // Read at least 8 bits (predicates are stored as 8-bit values)
1101
- unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1102
- unsigned int FromType;
1103
1101
1104
1102
// Vector Setting
1105
1103
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
1106
1104
if (SimpleVT.isVector ()) {
1107
- if (Isv2x16VT (LoadedVT) || LoadedVT == MVT::v4i8)
1108
- // v2f16/v2bf16/v2i16 is loaded using ld.b32
1109
- FromTypeWidth = 32 ;
1110
- else if (LoadedVT == MVT::v2f32)
1111
- // v2f32 is loaded using ld.b64
1112
- FromTypeWidth = 64 ;
1113
- else
1114
- llvm_unreachable (" Unexpected vector type" );
1105
+ switch (LoadedVT.getSimpleVT ().SimpleTy ) {
1106
+ case MVT::v2f16:
1107
+ case MVT::v2bf16:
1108
+ case MVT::v2i16:
1109
+ case MVT::v4i8:
1110
+ case MVT::v2f32:
1111
+ ScalarVT = LoadedVT.getSimpleVT ();
1112
+ break ;
1113
+ default :
1114
+ llvm_unreachable (" Unsupported vector type for non-vector load" );
1115
+ }
1115
1116
}
1116
1117
1117
- if (PlainLoad && (PlainLoad->getExtensionType () == ISD::SEXTLOAD))
1118
+ // Read at least 8 bits (predicates are stored as 8-bit values)
1119
+ unsigned FromTypeWidth = std::max (8U , (unsigned )ScalarVT.getSizeInBits ());
1120
+ unsigned int FromType;
1121
+ if (PlainLoad && PlainLoad->getExtensionType () == ISD::SEXTLOAD)
1118
1122
FromType = NVPTX::PTXLdStInstCode::Signed;
1119
1123
else
1120
1124
FromType = getLdStRegType (ScalarVT);
@@ -1424,18 +1428,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1424
1428
// Type Setting: toType + toTypeWidth
1425
1429
// - for integer type, always use 'u'
1426
1430
MVT ScalarVT = SimpleVT.getScalarType ();
1427
- unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1428
1431
if (SimpleVT.isVector ()) {
1429
- if (Isv2x16VT (StoreVT) || StoreVT == MVT::v4i8)
1430
- // v2x16 is stored using st.b32
1431
- ToTypeWidth = 32 ;
1432
- else if (StoreVT == MVT::v2f32)
1433
- // v2f32 is stored using st.b64
1434
- ToTypeWidth = 64 ;
1435
- else
1436
- llvm_unreachable (" Unexpected vector type" );
1432
+ switch (StoreVT.getSimpleVT ().SimpleTy ) {
1433
+ case MVT::v2f16:
1434
+ case MVT::v2bf16:
1435
+ case MVT::v2i16:
1436
+ case MVT::v4i8:
1437
+ case MVT::v2f32:
1438
+ ScalarVT = StoreVT.getSimpleVT ();
1439
+ break ;
1440
+ default :
1441
+ llvm_unreachable (" Unsupported vector type for non-vector store" );
1442
+ }
1437
1443
}
1438
1444
1445
+ unsigned ToTypeWidth = ScalarVT.getSizeInBits ();
1439
1446
unsigned int ToType = getLdStRegType (ScalarVT);
1440
1447
1441
1448
// Create the machine instruction DAG
0 commit comments