Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,60 +1168,54 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {

bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
MemSDNode *MemSD = cast<MemSDNode>(N);
EVT LoadedVT = MemSD->getMemoryVT();
if (!LoadedVT.isSimple())
const EVT MemEVT = MemSD->getMemoryVT();
if (!MemEVT.isSimple())
return false;
const MVT MemVT = MemEVT.getSimpleVT();

// Address Space Setting
unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
return tryLDGLDU(N);
}

EVT EltVT = N->getValueType(0);
SDLoc DL(N);
SDValue Chain = N->getOperand(0);
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);

// Vector Setting
MVT SimpleVT = LoadedVT.getSimpleVT();

// Type Setting: fromType + fromTypeWidth
//
// Sign : ISD::SEXTLOAD
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
// type is integer
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
MVT ScalarVT = SimpleVT.getScalarType();
// Read at least 8 bits (predicates are stored as 8-bit values)
unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
unsigned int FromType;
// The last operand holds the original LoadSDNode::getExtensionType() value
unsigned ExtensionType = cast<ConstantSDNode>(
N->getOperand(N->getNumOperands() - 1))->getZExtValue();
if (ExtensionType == ISD::SEXTLOAD)
FromType = NVPTX::PTXLdStInstCode::Signed;
else
FromType = getLdStRegType(ScalarVT);
const unsigned TotalWidth = MemVT.getSizeInBits();
unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: getLdStRegType(MemVT.getScalarType());

unsigned VecType;

unsigned FromTypeWidth;
switch (N->getOpcode()) {
case NVPTXISD::LoadV2:
FromTypeWidth = TotalWidth / 2;
VecType = NVPTX::PTXLdStInstCode::V2;
break;
case NVPTXISD::LoadV4:
FromTypeWidth = TotalWidth / 4;
VecType = NVPTX::PTXLdStInstCode::V4;
break;
default:
return false;
}

EVT EltVT = N->getValueType(0);

if (isVectorElementTypeUpsized(EltVT)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably rename it to something more meaningful. isSubVectorPackedInI32?
Otherwise it's not clear why we're using i32 here without having to go and look at the implementation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I've switched to isSubVectorPackedInI32. This function could be used in several other places as well, but I'm leaving that for a subsequent change.

EltVT = MVT::i32;
FromType = NVPTX::PTXLdStInstCode::Untyped;
FromTypeWidth = 32;
}

SDValue Offset, Base;
Expand Down Expand Up @@ -1271,9 +1265,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// LDG/LDU SD node (from custom vector handling), then it's the second operand
SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);

EVT OrigType = N->getValueType(0);
const EVT OrigType = N->getValueType(0);
EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;

if (EltVT == MVT::i128 || EltVT == MVT::f128) {
EltVT = MVT::i64;
NumElts = 2;
}
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
Expand All @@ -1293,11 +1292,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// Build the "promoted" result VTList for the load. If we are really loading
// i8s, then the return type will be promoted to i16 since we do not expose
// 8-bit registers in NVPTX.
EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
const EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
SmallVector<EVT, 5> InstVTs;
for (unsigned i = 0; i != NumElts; ++i) {
InstVTs.push_back(NodeVT);
}
InstVTs.append(NumElts, NodeVT);
InstVTs.push_back(MVT::Other);
SDVTList InstVTList = CurDAG->getVTList(InstVTs);
SDValue Chain = N->getOperand(0);
Expand Down Expand Up @@ -1476,6 +1473,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
EVT EltVT = Op1.getValueType();
MemSDNode *MemSD = cast<MemSDNode>(N);
EVT StoreVT = MemSD->getMemoryVT();
assert(StoreVT.isSimple() && "Store value is not simple");

// Address Space Setting
unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
Expand All @@ -1490,26 +1488,27 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {

// Type Setting: toType + toTypeWidth
// - for integer type, always use 'u'
assert(StoreVT.isSimple() && "Store value is not simple");
MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
unsigned ToTypeWidth = ScalarVT.getSizeInBits();
unsigned ToType = getLdStRegType(ScalarVT);
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());

SmallVector<SDValue, 12> Ops;
SDValue N2;
unsigned VecType;
unsigned ToTypeWidth;

switch (N->getOpcode()) {
case NVPTXISD::StoreV2:
VecType = NVPTX::PTXLdStInstCode::V2;
Ops.append({N->getOperand(1), N->getOperand(2)});
N2 = N->getOperand(3);
ToTypeWidth = TotalWidth / 2;
break;
case NVPTXISD::StoreV4:
VecType = NVPTX::PTXLdStInstCode::V4;
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
N->getOperand(4)});
N2 = N->getOperand(5);
ToTypeWidth = TotalWidth / 4;
break;
default:
return false;
Expand All @@ -1518,7 +1517,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
if (isVectorElementTypeUpsized(EltVT)) {
EltVT = MVT::i32;
ToType = NVPTX::PTXLdStInstCode::Untyped;
ToTypeWidth = 32;
}

SDValue Offset, Base;
Expand Down
Loading
Loading