Skip to content

Commit 4f387eb

Browse files
AlexMacleanIanWood1
authored andcommitted
[NVPTX][NFC] Refactoring and cleanup in NVPTXISelLowering (llvm#137222)
1 parent 65ef98c commit 4f387eb

File tree

1 file changed

+79
-114
lines changed

1 file changed

+79
-114
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 79 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,6 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
493493
return VectorInfo;
494494
}
495495

496-
static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
497-
SDValue Value) {
498-
if (Value->getValueType(0) == VT)
499-
return Value;
500-
return DAG.getNode(ISD::BITCAST, DL, VT, Value);
501-
}
502-
503496
// NVPTXTargetLowering Constructor.
504497
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
505498
const NVPTXSubtarget &STI)
@@ -1587,9 +1580,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15871580

15881581
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
15891582
SmallVector<SDValue, 6> StoreOperands;
1590-
for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
1591-
EVT EltVT = VTs[j];
1592-
int CurOffset = Offsets[j];
1583+
for (const unsigned J : llvm::seq(VTs.size())) {
1584+
EVT EltVT = VTs[J];
1585+
const int CurOffset = Offsets[J];
15931586
MaybeAlign PartAlign;
15941587
if (NeedAlign)
15951588
PartAlign = commonAlignment(ArgAlign, CurOffset);
@@ -1629,7 +1622,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16291622

16301623
// If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
16311624
// scalar store. In such cases, fall back to byte stores.
1632-
if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
1625+
if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
16331626
PartAlign.value() <
16341627
DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
16351628
assert(StoreOperands.empty() && "Unfinished preceeding store.");
@@ -1645,7 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16451638
}
16461639

16471640
// New store.
1648-
if (VectorInfo[j] & PVF_FIRST) {
1641+
if (VectorInfo[J] & PVF_FIRST) {
16491642
assert(StoreOperands.empty() && "Unfinished preceding store.");
16501643
StoreOperands.push_back(Chain);
16511644
StoreOperands.push_back(
@@ -1665,8 +1658,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16651658
// Record the value to store.
16661659
StoreOperands.push_back(StVal);
16671660

1668-
if (VectorInfo[j] & PVF_LAST) {
1669-
unsigned NumElts = StoreOperands.size() - 3;
1661+
if (VectorInfo[J] & PVF_LAST) {
1662+
const unsigned NumElts = StoreOperands.size() - 3;
16701663
NVPTXISD::NodeType Op;
16711664
switch (NumElts) {
16721665
case 1:
@@ -2168,7 +2161,7 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
21682161
ISD::OR, DL, MVT::i16,
21692162
{Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
21702163
EVT ToVT = Op->getValueType(0);
2171-
return MaybeBitcast(DAG, DL, ToVT, AsInt);
2164+
return DAG.getBitcast(ToVT, AsInt);
21722165
}
21732166

21742167
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
@@ -3367,18 +3360,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33673360
auto PtrVT = getPointerTy(DAG.getDataLayout());
33683361

33693362
const Function *F = &MF.getFunction();
3370-
const AttributeList &PAL = F->getAttributes();
3371-
const TargetLowering *TLI = STI.getTargetLowering();
33723363

33733364
SDValue Root = DAG.getRoot();
3374-
std::vector<SDValue> OutChains;
3365+
SmallVector<SDValue, 16> OutChains;
33753366

3376-
std::vector<Type *> argTypes;
3377-
std::vector<const Argument *> theArgs;
3378-
for (const Argument &I : F->args()) {
3379-
theArgs.push_back(&I);
3380-
argTypes.push_back(I.getType());
3381-
}
33823367
// argTypes.size() (or theArgs.size()) and Ins.size() need not match.
33833368
// Ins.size() will be larger
33843369
// * if there is an aggregate argument with multiple fields (each field
@@ -3388,75 +3373,85 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33883373
// individually present in Ins.
33893374
// So a different index should be used for indexing into Ins.
33903375
// See similar issue in LowerCall.
3391-
unsigned InsIdx = 0;
33923376

3393-
for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
3394-
Type *Ty = argTypes[i];
3377+
auto AllIns = ArrayRef(Ins);
3378+
for (const auto &Arg : F->args()) {
3379+
const auto ArgIns = AllIns.take_while(
3380+
[&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
3381+
AllIns = AllIns.drop_front(ArgIns.size());
33953382

3396-
if (theArgs[i]->use_empty()) {
3397-
// argument is dead
3398-
if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
3399-
SmallVector<EVT, 16> vtparts;
3383+
Type *Ty = Arg.getType();
34003384

3401-
ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
3402-
if (vtparts.empty())
3403-
report_fatal_error("Empty parameter types are not supported");
3385+
if (ArgIns.empty())
3386+
report_fatal_error("Empty parameter types are not supported");
34043387

3405-
for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
3406-
++parti) {
3407-
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3408-
++InsIdx;
3409-
}
3410-
if (vtparts.size() > 0)
3411-
--InsIdx;
3412-
continue;
3413-
}
3414-
if (Ty->isVectorTy()) {
3415-
EVT ObjectVT = getValueType(DL, Ty);
3416-
unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
3417-
for (unsigned parti = 0; parti < NumRegs; ++parti) {
3418-
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
3419-
++InsIdx;
3420-
}
3421-
if (NumRegs > 0)
3422-
--InsIdx;
3423-
continue;
3388+
if (Arg.use_empty()) {
3389+
// argument is dead
3390+
for (const auto &In : ArgIns) {
3391+
assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
3392+
InVals.push_back(DAG.getUNDEF(In.VT));
34243393
}
3425-
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
34263394
continue;
34273395
}
34283396

3397+
SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT);
3398+
34293399
// In the following cases, assign a node order of "i+1"
34303400
// to newly created nodes. The SDNodes for params have to
34313401
// appear in the same order as their order of appearance
34323402
// in the original function. "i+1" holds that order.
3433-
if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
3403+
if (Arg.hasByValAttr()) {
3404+
// Param has ByVal attribute
3405+
// Return MoveParam(param symbol).
3406+
// Ideally, the param symbol can be returned directly,
3407+
// but when SDNode builder decides to use it in a CopyToReg(),
3408+
// machine instruction fails because TargetExternalSymbol
3409+
// (not lowered) is target dependent, and CopyToReg assumes
3410+
// the source is lowered.
3411+
assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
3412+
const auto &ByvalIn = ArgIns[0];
3413+
assert(getValueType(DL, Ty) == ByvalIn.VT &&
3414+
"Ins type did not match function type");
3415+
assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
3416+
3417+
SDValue P;
3418+
if (isKernelFunction(*F)) {
3419+
P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol);
3420+
P.getNode()->setIROrder(Arg.getArgNo() + 1);
3421+
} else {
3422+
P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
3423+
P.getNode()->setIROrder(Arg.getArgNo() + 1);
3424+
P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL,
3425+
ADDRESS_SPACE_GENERIC);
3426+
}
3427+
InVals.push_back(P);
3428+
} else {
34343429
bool aggregateIsPacked = false;
34353430
if (StructType *STy = dyn_cast<StructType>(Ty))
34363431
aggregateIsPacked = STy->isPacked();
34373432

34383433
SmallVector<EVT, 16> VTs;
34393434
SmallVector<uint64_t, 16> Offsets;
34403435
ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
3441-
if (VTs.empty())
3442-
report_fatal_error("Empty parameter types are not supported");
3436+
assert(VTs.size() == ArgIns.size() && "Size mismatch");
3437+
assert(VTs.size() == Offsets.size() && "Size mismatch");
34433438

34443439
Align ArgAlign = getFunctionArgumentAlignment(
3445-
F, Ty, i + AttributeList::FirstArgIndex, DL);
3440+
F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
34463441
auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
3442+
assert(VectorInfo.size() == VTs.size() && "Size mismatch");
34473443

3448-
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
34493444
int VecIdx = -1; // Index of the first element of the current vector.
3450-
for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
3451-
if (VectorInfo[parti] & PVF_FIRST) {
3445+
for (const unsigned PartI : llvm::seq(VTs.size())) {
3446+
if (VectorInfo[PartI] & PVF_FIRST) {
34523447
assert(VecIdx == -1 && "Orphaned vector.");
3453-
VecIdx = parti;
3448+
VecIdx = PartI;
34543449
}
34553450

34563451
// That's the last element of this store op.
3457-
if (VectorInfo[parti] & PVF_LAST) {
3458-
unsigned NumElts = parti - VecIdx + 1;
3459-
EVT EltVT = VTs[parti];
3452+
if (VectorInfo[PartI] & PVF_LAST) {
3453+
const unsigned NumElts = PartI - VecIdx + 1;
3454+
EVT EltVT = VTs[PartI];
34603455
// i1 is loaded/stored as i8.
34613456
EVT LoadVT = EltVT;
34623457
if (EltVT == MVT::i1)
@@ -3469,10 +3464,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34693464

34703465
EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
34713466
SDValue VecAddr =
3472-
DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
3467+
DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol,
34733468
DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
3474-
Value *srcValue = Constant::getNullValue(
3475-
PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM));
34763469

34773470
const MaybeAlign PartAlign = [&]() -> MaybeAlign {
34783471
if (aggregateIsPacked)
@@ -3481,23 +3474,23 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34813474
return std::nullopt;
34823475
Align PartAlign =
34833476
DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
3484-
return commonAlignment(PartAlign, Offsets[parti]);
3477+
return commonAlignment(PartAlign, Offsets[PartI]);
34853478
}();
3486-
SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
3487-
MachinePointerInfo(srcValue), PartAlign,
3488-
MachineMemOperand::MODereferenceable |
3489-
MachineMemOperand::MOInvariant);
3479+
SDValue P =
3480+
DAG.getLoad(VecVT, dl, Root, VecAddr,
3481+
MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
3482+
MachineMemOperand::MODereferenceable |
3483+
MachineMemOperand::MOInvariant);
34903484
if (P.getNode())
3491-
P.getNode()->setIROrder(i + 1);
3492-
for (unsigned j = 0; j < NumElts; ++j) {
3485+
P.getNode()->setIROrder(Arg.getArgNo() + 1);
3486+
for (const unsigned J : llvm::seq(NumElts)) {
34933487
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3494-
DAG.getIntPtrConstant(j, dl));
3488+
DAG.getIntPtrConstant(J, dl));
34953489
// We've loaded i1 as an i8 and now must truncate it back to i1
34963490
if (EltVT == MVT::i1)
34973491
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
34983492
// v2f16 was loaded as an i32. Now we must bitcast it back.
3499-
else if (EltVT != LoadVT)
3500-
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
3493+
Elt = DAG.getBitcast(EltVT, Elt);
35013494

35023495
// If a promoted integer type is used, truncate down to the original
35033496
MVT PromotedVT;
@@ -3507,53 +3500,25 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
35073500

35083501
// Extend the element if necessary (e.g. an i8 is loaded
35093502
// into an i16 register)
3510-
if (Ins[InsIdx].VT.isInteger() &&
3511-
Ins[InsIdx].VT.getFixedSizeInBits() >
3512-
LoadVT.getFixedSizeInBits()) {
3513-
unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
3514-
: ISD::ZERO_EXTEND;
3515-
Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
3503+
if (ArgIns[PartI].VT.getFixedSizeInBits() !=
3504+
LoadVT.getFixedSizeInBits()) {
3505+
assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() &&
3506+
"Non-integer argument type size mismatch");
3507+
Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
3508+
ArgIns[PartI].VT);
35163509
}
35173510
InVals.push_back(Elt);
35183511
}
35193512

35203513
// Reset vector tracking state.
35213514
VecIdx = -1;
35223515
}
3523-
++InsIdx;
35243516
}
3525-
if (VTs.size() > 0)
3526-
--InsIdx;
3527-
continue;
3528-
}
3529-
3530-
// Param has ByVal attribute
3531-
// Return MoveParam(param symbol).
3532-
// Ideally, the param symbol can be returned directly,
3533-
// but when SDNode builder decides to use it in a CopyToReg(),
3534-
// machine instruction fails because TargetExternalSymbol
3535-
// (not lowered) is target dependent, and CopyToReg assumes
3536-
// the source is lowered.
3537-
EVT ObjectVT = getValueType(DL, Ty);
3538-
assert(ObjectVT == Ins[InsIdx].VT &&
3539-
"Ins type did not match function type");
3540-
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3541-
3542-
SDValue P;
3543-
if (isKernelFunction(*F)) {
3544-
P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
3545-
P.getNode()->setIROrder(i + 1);
3546-
} else {
3547-
P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3548-
P.getNode()->setIROrder(i + 1);
3549-
P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
3550-
ADDRESS_SPACE_GENERIC);
35513517
}
3552-
InVals.push_back(P);
35533518
}
35543519

35553520
if (!OutChains.empty())
3556-
DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
3521+
DAG.setRoot(DAG.getTokenFactor(dl, OutChains));
35573522

35583523
return Chain;
35593524
}
@@ -5784,7 +5749,7 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
57845749

57855750
// Bitcast to i16 and unpack elements into a vector
57865751
SDLoc DL(Node);
5787-
SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
5752+
SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0));
57885753
SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
57895754
SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
57905755
SDValue Vec1 =

0 commit comments

Comments
 (0)