Skip to content

Commit 277bc3a

Browse files
committed
[NVPTX] loads, stores of v2f32 are untyped
Ensures ld.b64 and st.b64 for v2f32. Also remove -O3 in f32x2-instructions.ll test.
1 parent b365375 commit 277bc3a

File tree

2 files changed

+1037
-2100
lines changed

2 files changed

+1037
-2100
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,7 @@ static int getLdStRegType(EVT VT) {
10591059
case MVT::bf16:
10601060
case MVT::v2f16:
10611061
case MVT::v2bf16:
1062+
case MVT::v2f32:
10621063
return NVPTX::PTXLdStInstCode::Untyped;
10631064
default:
10641065
return NVPTX::PTXLdStInstCode::Float;
@@ -1097,24 +1098,27 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10971098
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
10981099
MVT SimpleVT = LoadedVT.getSimpleVT();
10991100
MVT ScalarVT = SimpleVT.getScalarType();
1100-
// Read at least 8 bits (predicates are stored as 8-bit values)
1101-
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1102-
unsigned int FromType;
11031101

11041102
// Vector Setting
11051103
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11061104
if (SimpleVT.isVector()) {
1107-
if (Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8)
1108-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1109-
FromTypeWidth = 32;
1110-
else if (LoadedVT == MVT::v2f32)
1111-
// v2f32 is loaded using ld.b64
1112-
FromTypeWidth = 64;
1113-
else
1114-
llvm_unreachable("Unexpected vector type");
1105+
switch (LoadedVT.getSimpleVT().SimpleTy) {
1106+
case MVT::v2f16:
1107+
case MVT::v2bf16:
1108+
case MVT::v2i16:
1109+
case MVT::v4i8:
1110+
case MVT::v2f32:
1111+
ScalarVT = LoadedVT.getSimpleVT();
1112+
break;
1113+
default:
1114+
llvm_unreachable("Unsupported vector type for non-vector load");
1115+
}
11151116
}
11161117

1117-
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
1118+
// Read at least 8 bits (predicates are stored as 8-bit values)
1119+
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1120+
unsigned int FromType;
1121+
if (PlainLoad && PlainLoad->getExtensionType() == ISD::SEXTLOAD)
11181122
FromType = NVPTX::PTXLdStInstCode::Signed;
11191123
else
11201124
FromType = getLdStRegType(ScalarVT);
@@ -1424,18 +1428,21 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14241428
// Type Setting: toType + toTypeWidth
14251429
// - for integer type, always use 'u'
14261430
MVT ScalarVT = SimpleVT.getScalarType();
1427-
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
14281431
if (SimpleVT.isVector()) {
1429-
if (Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8)
1430-
// v2x16 is stored using st.b32
1431-
ToTypeWidth = 32;
1432-
else if (StoreVT == MVT::v2f32)
1433-
// v2f32 is stored using st.b64
1434-
ToTypeWidth = 64;
1435-
else
1436-
llvm_unreachable("Unexpected vector type");
1432+
switch (StoreVT.getSimpleVT().SimpleTy) {
1433+
case MVT::v2f16:
1434+
case MVT::v2bf16:
1435+
case MVT::v2i16:
1436+
case MVT::v4i8:
1437+
case MVT::v2f32:
1438+
ScalarVT = StoreVT.getSimpleVT();
1439+
break;
1440+
default:
1441+
llvm_unreachable("Unsupported vector type for non-vector store");
1442+
}
14371443
}
14381444

1445+
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
14391446
unsigned int ToType = getLdStRegType(ScalarVT);
14401447

14411448
// Create the machine instruction DAG

0 commit comments

Comments
 (0)