Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 79 additions & 114 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,13 +493,6 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
return VectorInfo;
}

// Returns Value reinterpreted as type VT. If the first result of Value already
// has type VT, Value is handed back untouched; otherwise a new ISD::BITCAST
// node to VT is created at location DL.
static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
                            SDValue Value) {
  const bool AlreadyHasType = Value->getValueType(0) == VT;
  // Only build a node when the types actually differ.
  return AlreadyHasType ? Value : DAG.getNode(ISD::BITCAST, DL, VT, Value);
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
const NVPTXSubtarget &STI)
Expand Down Expand Up @@ -1587,9 +1580,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,

auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
SmallVector<SDValue, 6> StoreOperands;
for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
EVT EltVT = VTs[j];
int CurOffset = Offsets[j];
for (const unsigned J : llvm::seq(VTs.size())) {
EVT EltVT = VTs[J];
const int CurOffset = Offsets[J];
MaybeAlign PartAlign;
if (NeedAlign)
PartAlign = commonAlignment(ArgAlign, CurOffset);
Expand Down Expand Up @@ -1629,7 +1622,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,

// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
// scalar store. In such cases, fall back to byte stores.
if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
PartAlign.value() <
DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
assert(StoreOperands.empty() && "Unfinished preceeding store.");
Expand All @@ -1645,7 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}

// New store.
if (VectorInfo[j] & PVF_FIRST) {
if (VectorInfo[J] & PVF_FIRST) {
assert(StoreOperands.empty() && "Unfinished preceding store.");
StoreOperands.push_back(Chain);
StoreOperands.push_back(
Expand All @@ -1665,8 +1658,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Record the value to store.
StoreOperands.push_back(StVal);

if (VectorInfo[j] & PVF_LAST) {
unsigned NumElts = StoreOperands.size() - 3;
if (VectorInfo[J] & PVF_LAST) {
const unsigned NumElts = StoreOperands.size() - 3;
NVPTXISD::NodeType Op;
switch (NumElts) {
case 1:
Expand Down Expand Up @@ -2168,7 +2161,7 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
ISD::OR, DL, MVT::i16,
{Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
EVT ToVT = Op->getValueType(0);
return MaybeBitcast(DAG, DL, ToVT, AsInt);
return DAG.getBitcast(ToVT, AsInt);
}

// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
Expand Down Expand Up @@ -3367,18 +3360,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
auto PtrVT = getPointerTy(DAG.getDataLayout());

const Function *F = &MF.getFunction();
const AttributeList &PAL = F->getAttributes();
const TargetLowering *TLI = STI.getTargetLowering();

SDValue Root = DAG.getRoot();
std::vector<SDValue> OutChains;
SmallVector<SDValue, 16> OutChains;

std::vector<Type *> argTypes;
std::vector<const Argument *> theArgs;
for (const Argument &I : F->args()) {
theArgs.push_back(&I);
argTypes.push_back(I.getType());
}
// argTypes.size() (or theArgs.size()) and Ins.size() need not match.
// Ins.size() will be larger
// * if there is an aggregate argument with multiple fields (each field
Expand All @@ -3388,75 +3373,85 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// individually present in Ins.
// So a different index should be used for indexing into Ins.
// See similar issue in LowerCall.
unsigned InsIdx = 0;

for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
Type *Ty = argTypes[i];
auto AllIns = ArrayRef(Ins);
for (const auto &Arg : F->args()) {
const auto ArgIns = AllIns.take_while(
[&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
AllIns = AllIns.drop_front(ArgIns.size());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. That makes it pretty obvious what we're doing here.


if (theArgs[i]->use_empty()) {
// argument is dead
if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;
Type *Ty = Arg.getType();

ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
if (vtparts.empty())
report_fatal_error("Empty parameter types are not supported");
if (ArgIns.empty())
report_fatal_error("Empty parameter types are not supported");

for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
++parti) {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
++InsIdx;
}
if (vtparts.size() > 0)
--InsIdx;
continue;
}
if (Ty->isVectorTy()) {
EVT ObjectVT = getValueType(DL, Ty);
unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
for (unsigned parti = 0; parti < NumRegs; ++parti) {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
++InsIdx;
}
if (NumRegs > 0)
--InsIdx;
continue;
if (Arg.use_empty()) {
// argument is dead
for (const auto &In : ArgIns) {
assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
InVals.push_back(DAG.getUNDEF(In.VT));
}
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
continue;
}

SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT);

// In the following cases, assign a node order of "argument number + 1"
// to newly created nodes. The SDNodes for params have to
// appear in the same order as their order of appearance
// in the original function. "argument number + 1" holds that order.
if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
if (Arg.hasByValAttr()) {
// Param has ByVal attribute
// Return MoveParam(param symbol).
// Ideally, the param symbol can be returned directly,
// but when SDNode builder decides to use it in a CopyToReg(),
// machine instruction fails because TargetExternalSymbol
// (not lowered) is target dependent, and CopyToReg assumes
// the source is lowered.
assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
const auto &ByvalIn = ArgIns[0];
assert(getValueType(DL, Ty) == ByvalIn.VT &&
"Ins type did not match function type");
assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");

SDValue P;
if (isKernelFunction(*F)) {
P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol);
P.getNode()->setIROrder(Arg.getArgNo() + 1);
} else {
P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
P.getNode()->setIROrder(Arg.getArgNo() + 1);
P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL,
ADDRESS_SPACE_GENERIC);
}
InVals.push_back(P);
} else {
bool aggregateIsPacked = false;
if (StructType *STy = dyn_cast<StructType>(Ty))
aggregateIsPacked = STy->isPacked();

SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
if (VTs.empty())
report_fatal_error("Empty parameter types are not supported");
assert(VTs.size() == ArgIns.size() && "Size mismatch");
assert(VTs.size() == Offsets.size() && "Size mismatch");

Align ArgAlign = getFunctionArgumentAlignment(
F, Ty, i + AttributeList::FirstArgIndex, DL);
F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
assert(VectorInfo.size() == VTs.size() && "Size mismatch");

SDValue Arg = getParamSymbol(DAG, i, PtrVT);
int VecIdx = -1; // Index of the first element of the current vector.
for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
if (VectorInfo[parti] & PVF_FIRST) {
for (const unsigned PartI : llvm::seq(VTs.size())) {
if (VectorInfo[PartI] & PVF_FIRST) {
assert(VecIdx == -1 && "Orphaned vector.");
VecIdx = parti;
VecIdx = PartI;
}

// That's the last element of this load op.
if (VectorInfo[parti] & PVF_LAST) {
unsigned NumElts = parti - VecIdx + 1;
EVT EltVT = VTs[parti];
if (VectorInfo[PartI] & PVF_LAST) {
const unsigned NumElts = PartI - VecIdx + 1;
EVT EltVT = VTs[PartI];
// i1 is loaded/stored as i8.
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
Expand All @@ -3469,10 +3464,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(

EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
SDValue VecAddr =
DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol,
DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
Value *srcValue = Constant::getNullValue(
PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM));

const MaybeAlign PartAlign = [&]() -> MaybeAlign {
if (aggregateIsPacked)
Expand All @@ -3481,23 +3474,23 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
return std::nullopt;
Align PartAlign =
DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
return commonAlignment(PartAlign, Offsets[parti]);
return commonAlignment(PartAlign, Offsets[PartI]);
}();
SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
MachinePointerInfo(srcValue), PartAlign,
MachineMemOperand::MODereferenceable |
MachineMemOperand::MOInvariant);
SDValue P =
DAG.getLoad(VecVT, dl, Root, VecAddr,
MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
MachineMemOperand::MODereferenceable |
MachineMemOperand::MOInvariant);
if (P.getNode())
P.getNode()->setIROrder(i + 1);
for (unsigned j = 0; j < NumElts; ++j) {
P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
DAG.getIntPtrConstant(j, dl));
DAG.getIntPtrConstant(J, dl));
// We've loaded i1 as an i8 and now must truncate it back to i1
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2f16 was loaded as an i32. Now we must bitcast it back.
else if (EltVT != LoadVT)
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
Elt = DAG.getBitcast(EltVT, Elt);

// If a promoted integer type is used, truncate down to the original
MVT PromotedVT;
Expand All @@ -3507,53 +3500,25 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(

// Extend the element if necessary (e.g. an i8 is loaded
// into an i16 register)
if (Ins[InsIdx].VT.isInteger() &&
Ins[InsIdx].VT.getFixedSizeInBits() >
LoadVT.getFixedSizeInBits()) {
unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;
Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
if (ArgIns[PartI].VT.getFixedSizeInBits() !=
LoadVT.getFixedSizeInBits()) {
assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() &&
"Non-integer argument type size mismatch");
Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
ArgIns[PartI].VT);
}
InVals.push_back(Elt);
}

// Reset vector tracking state.
VecIdx = -1;
}
++InsIdx;
}
if (VTs.size() > 0)
--InsIdx;
continue;
}

// Param has ByVal attribute
// Return MoveParam(param symbol).
// Ideally, the param symbol can be returned directly,
// but when SDNode builder decides to use it in a CopyToReg(),
// machine instruction fails because TargetExternalSymbol
// (not lowered) is target dependent, and CopyToReg assumes
// the source is lowered.
EVT ObjectVT = getValueType(DL, Ty);
assert(ObjectVT == Ins[InsIdx].VT &&
"Ins type did not match function type");
SDValue Arg = getParamSymbol(DAG, i, PtrVT);

SDValue P;
if (isKernelFunction(*F)) {
P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
P.getNode()->setIROrder(i + 1);
} else {
P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
P.getNode()->setIROrder(i + 1);
P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
ADDRESS_SPACE_GENERIC);
}
InVals.push_back(P);
}

if (!OutChains.empty())
DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
DAG.setRoot(DAG.getTokenFactor(dl, OutChains));

return Chain;
}
Expand Down Expand Up @@ -5784,7 +5749,7 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,

// Bitcast to i16 and unpack elements into a vector
SDLoc DL(Node);
SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0));
SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
SDValue Vec1 =
Expand Down
Loading