84 changes: 0 additions & 84 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -151,12 +151,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
if (tryLoadParam(N))
return;
break;
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
if (tryStoreRetval(N))
return;
break;
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
@@ -1504,84 +1498,6 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
return true;
}

bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
SDLoc DL(N);
SDValue Chain = N->getOperand(0);
SDValue Offset = N->getOperand(1);
unsigned OffsetVal = Offset->getAsZExtVal();
MemSDNode *Mem = cast<MemSDNode>(N);

// How many elements do we have?
unsigned NumElts = 1;
switch (N->getOpcode()) {
default:
return false;
case NVPTXISD::StoreRetval:
NumElts = 1;
break;
case NVPTXISD::StoreRetvalV2:
NumElts = 2;
break;
case NVPTXISD::StoreRetvalV4:
NumElts = 4;
break;
}

// Build vector of operands
SmallVector<SDValue, 6> Ops;
for (unsigned i = 0; i < NumElts; ++i)
Ops.push_back(N->getOperand(i + 2));
Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain});

// Determine target opcode
// If we have an i1, use an 8-bit store. The lowering code in
// NVPTXISelLowering will have already emitted an upcast.
std::optional<unsigned> Opcode = 0;
switch (NumElts) {
default:
return false;
case 1:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64);
if (Opcode == NVPTX::StoreRetvalI8) {
// Fine-tune the opcode depending on the size of the operand.
// This helps to avoid creating redundant COPY instructions in
// InstrEmitter::AddRegisterOperand().
switch (Ops[0].getSimpleValueType().SimpleTy) {
default:
break;
case MVT::i32:
Opcode = NVPTX::StoreRetvalI8TruncI32;
break;
case MVT::i64:
Opcode = NVPTX::StoreRetvalI8TruncI64;
break;
}
}
break;
case 2:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16,
NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64);
break;
case 4:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
NVPTX::StoreRetvalV4I32, {/* no v4i64 */});
break;
}
if (!Opcode)
return false;

SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});

ReplaceNode(N, Ret);
return true;
}

// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
#define getOpcV2H(ty, opKind0, opKind1) \
NVPTX::StoreParamV2##ty##_##opKind0##opKind1
1 change: 0 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -79,7 +79,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryStore(SDNode *N);
bool tryStoreVector(SDNode *N);
bool tryLoadParam(SDNode *N);
bool tryStoreRetval(SDNode *N);
bool tryStoreParam(SDNode *N);
bool tryFence(SDNode *N);
void SelectAddrSpaceCast(SDNode *N);
177 changes: 57 additions & 120 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -370,7 +370,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
} else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
// v2i8 is promoted to v2i16
NumElts = 1;
EltVT = MVT::v2i16;
EltVT = MVT::v2i8;
}
for (unsigned j = 0; j != NumElts; ++j) {
ValueVTs.push_back(EltVT);
Expand Down Expand Up @@ -1065,9 +1065,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::StoreParamV2)
MAKE_CASE(NVPTXISD::StoreParamV4)
MAKE_CASE(NVPTXISD::MoveParam)
MAKE_CASE(NVPTXISD::StoreRetval)
MAKE_CASE(NVPTXISD::StoreRetvalV2)
MAKE_CASE(NVPTXISD::StoreRetvalV4)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
MAKE_CASE(NVPTXISD::CallPrototype)
Expand Down Expand Up @@ -1438,7 +1435,11 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
}

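// Map the argument's ABI extension flags to an ISD extension opcode; with
// neither sext nor zext set, any-extend leaves the upper bits unspecified.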
static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
if (Flags.isSExt())
return ISD::SIGN_EXTEND;
if (Flags.isZExt())
return ISD::ZERO_EXTEND;
return ISD::ANY_EXTEND;
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
@@ -3373,10 +3374,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
}
InVals.push_back(P);
} else {
bool aggregateIsPacked = false;
if (StructType *STy = dyn_cast<StructType>(Ty))
aggregateIsPacked = STy->isPacked();

SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
@@ -3389,9 +3386,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
unsigned I = 0;
for (const unsigned NumElts : VectorInfo) {
const EVT EltVT = VTs[I];
// i1 is loaded/stored as i8
const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
// If the element is a packed type (e.g. v2f16, v4i8), it holds multiple
// scalar elements.
const unsigned PackingAmt =
@@ -3403,14 +3399,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));

const MaybeAlign PartAlign = [&]() -> MaybeAlign {
if (aggregateIsPacked)
return Align(1);
if (NumElts != 1)
return std::nullopt;
Align PartAlign = DAG.getEVTAlign(EltVT);
return commonAlignment(PartAlign, Offsets[I]);
}();
const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
SDValue P =
DAG.getLoad(VecVT, dl, Root, VecAddr,
MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
@@ -3419,23 +3408,22 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (P.getNode())
P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
: ISD::EXTRACT_VECTOR_ELT,
dl, LoadVT, P,
DAG.getIntPtrConstant(J * PackingAmt, dl));
SDValue Elt = DAG.getNode(
LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
: ISD::EXTRACT_VECTOR_ELT,
dl, LoadVT, P, DAG.getVectorIdxConstant(J * PackingAmt, dl));

// Extend or truncate the element if necessary (e.g. an i8 is loaded
// into an i16 register)
const EVT ExpactedVT = ArgIns[I + J].VT;
assert((Elt.getValueType().bitsEq(ExpactedVT) ||
(ExpactedVT.isScalarInteger() &&
Elt.getValueType().isScalarInteger())) &&
const EVT ExpectedVT = ArgIns[I + J].VT;
assert((Elt.getValueType() == ExpectedVT ||
(ExpectedVT.isInteger() && Elt.getValueType().isInteger())) &&
"Non-integer argument type size mismatch");
if (ExpactedVT.bitsGT(Elt.getValueType()))
Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
if (ExpectedVT.bitsGT(Elt.getValueType()))
Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpectedVT,
Elt);
else if (ExpactedVT.bitsLT(Elt.getValueType()))
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
else if (ExpectedVT.bitsLT(Elt.getValueType()))
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Elt);
InVals.push_back(Elt);
}
I += NumElts;
@@ -3449,33 +3437,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
return Chain;
}

// Use byte-store when the param address of the return value is unaligned.
// This may happen when the return value is a field of a packed structure.
static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
uint64_t Offset, EVT ElementType,
SDValue RetVal, const SDLoc &dl) {
// Bit logic only works on integer types
if (adjustElementType(ElementType))
RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);

// Store each byte
for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
// Shift the byte to the last byte position
SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
DAG.getConstant(i * 8, dl, MVT::i32));
SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
ShiftVal};
// Truncating store of only the last byte using st.param.b8; the
// register type can be larger than b8.
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
DAG.getVTList(MVT::Other), StoreOperands,
MVT::i8, MachinePointerInfo(), std::nullopt,
MachineMemOperand::MOStore);
}
return Chain;
}

SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
bool isVarArg,
@@ -3497,10 +3458,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");

for (const unsigned I : llvm::seq(VTs.size()))
if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
VTs[I] = *PromotedVT;

// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
// they are signed or unsigned types.
@@ -3512,12 +3469,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
"OutVal type should always be legal");

if (ExtendIntegerRetVal) {
RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
} else if (RetVal.getValueSizeInBits() < 16) {
// Use 16-bit registers for small load-stores as it's the
// smallest general purpose register size supported by NVPTX.
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
EVT VTI = VTs[I];
if (const auto PromotedVT = PromoteScalarIntegerPTX(VTI))
VTI = *PromotedVT;

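// i1 has no PTX store type, so it is stored as i8; when the ABI requires
// extension (see the interoperability note above), small integers are
// widened to i32 instead.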
const EVT StoreVT =
ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);

assert((RetVal.getValueType() == StoreVT ||
(StoreVT.isInteger() && RetVal.getValueType().isInteger())) &&
"Non-integer argument type size mismatch");
if (StoreVT.bitsGT(RetVal.getValueType())) {
RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, StoreVT, RetVal);
} else if (StoreVT.bitsLT(RetVal.getValueType())) {
RetVal = DAG.getNode(ISD::TRUNCATE, dl, StoreVT, RetVal);
}
return RetVal;
};
@@ -3526,45 +3491,34 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
for (const unsigned NumElts : VectorInfo) {
const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
if (NumElts == 1 && RetTy->isAggregateType() &&
CurrentAlign < DAG.getEVTAlign(VTs[I])) {
Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
GetRetVal(I), dl);

// The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
// into the graph, so just move on to the next element.
I++;
continue;
}
const MaybeAlign CurrentAlign = ExtendIntegerRetVal
? MaybeAlign(std::nullopt)
: commonAlignment(RetAlign, Offsets[I]);

SmallVector<SDValue, 6> StoreOperands{
Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};

for (const unsigned J : llvm::seq(NumElts))
StoreOperands.push_back(GetRetVal(I + J));
SDValue Val;
if (NumElts == 1) {
Val = GetRetVal(I);
} else {
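// Multiple elements: flatten any packed values (e.g. v2f16) into scalars
// and rebuild a single wide vector so one store covers the whole group.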
SmallVector<SDValue, 6> StoreVals;
for (const unsigned J : llvm::seq(NumElts)) {
SDValue ValJ = GetRetVal(I + J);
if (ValJ.getValueType().isVector())
DAG.ExtractVectorElements(ValJ, StoreVals);
else
StoreVals.push_back(ValJ);
}

NVPTXISD::NodeType Op;
switch (NumElts) {
case 1:
Op = NVPTXISD::StoreRetval;
break;
case 2:
Op = NVPTXISD::StoreRetvalV2;
break;
case 4:
Op = NVPTXISD::StoreRetvalV4;
break;
default:
llvm_unreachable("Invalid vector info.");
EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
StoreVals.size());
Val = DAG.getBuildVector(VT, dl, StoreVals);
}

// Adjust type of load/store op if we've extended the scalar
// return value.
EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
Chain = DAG.getMemIntrinsicNode(
Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
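// The return value is now written with a generic store through the
// "func_retval0" symbol in the param address space; instruction selection
// is expected to lower this to an st.param access.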
SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
SDValue Ptr =
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));

Chain = DAG.getStore(Chain, dl, Val, Ptr,
MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);

I += NumElts;
}
@@ -5120,19 +5074,12 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
case NVPTXISD::StoreParamV2:
Opcode = NVPTXISD::StoreParamV4;
break;
case NVPTXISD::StoreRetval:
Opcode = NVPTXISD::StoreRetvalV2;
break;
case NVPTXISD::StoreRetvalV2:
Opcode = NVPTXISD::StoreRetvalV4;
break;
case NVPTXISD::StoreV2:
MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
case NVPTXISD::StoreParamV4:
case NVPTXISD::StoreRetvalV4:
case NVPTXISD::StoreV8:
// PTX doesn't support the next doubling of operands
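// (e.g. a v4 param store has no v8 form)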
return SDValue();
@@ -5201,12 +5148,6 @@ static SDValue PerformStoreParamCombine(SDNode *N,
return PerformStoreCombineHelper(N, DCI, 3, 1);
}

static SDValue PerformStoreRetvalCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 2nd to the last one are the values to be stored
return PerformStoreCombineHelper(N, DCI, 2, 0);
}

/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
@@ -5840,10 +5781,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
return combineUnpackingMovIntoLoad(N, DCI);
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
return PerformStoreRetvalCombine(N, DCI);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4: