Skip to content

Commit aad4cad

Browse files
committed
more cleanup
1 parent c8cc587 commit aad4cad

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+387
-1246
lines changed

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ __device__ __bf16 test_call( __bf16 in) {
3737
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
3838
// CHECK: st.param.b16 [param0], %[[R]];
3939
// CHECK: .param .align 2 .b8 retval0[2];
40-
// CHECK: call.uni (retval0),
41-
// CHECK-NEXT: _Z13external_funcDF16b,
42-
// CHECK-NEXT: (
43-
// CHECK-NEXT: param0
44-
// CHECK-NEXT );
40+
// CHECK: call.uni (retval0), _Z13external_funcDF16b, (param0);
4541
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0];
4642
return external_func(in);
4743
// CHECK: st.param.b16 [func_retval0], %[[RET]]

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,25 @@ void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
457457
}
458458
llvm_unreachable("Invalid cta_group in printCTAGroup");
459459
}
460+
461+
void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
462+
raw_ostream &O, StringRef Modifier) {
463+
const MCOperand &MO = MI->getOperand(OpNum);
464+
assert(MO.isImm() && "Invalid operand");
465+
const auto Imm = MO.getImm();
466+
467+
if (Modifier == "RetList") {
468+
assert((Imm == 1 || Imm == 0) && "Invalid return list");
469+
if (Imm)
470+
O << " (retval0),";
471+
return;
472+
}
473+
474+
if (Modifier == "ParamList") {
475+
assert(Imm >= 0 && "Invalid parameter list");
476+
interleaveComma(llvm::seq(Imm), O,
477+
[&](const auto &I) { O << "param" << I; });
478+
return;
479+
}
480+
llvm_unreachable("Invalid modifier");
481+
}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
5252
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
5353
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
5454
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
55+
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
56+
StringRef Modifier = {});
5557
};
5658

5759
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
910910
}
911911
}
912912

913-
914913
void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
915914
SDValue Src = N->getOperand(0);
916915
AddrSpaceCastSDNode *CastN = cast<AddrSpaceCastSDNode>(N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 20 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,30 +1057,19 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10571057
MAKE_CASE(NVPTXISD::DeclareScalarParam)
10581058
MAKE_CASE(NVPTXISD::DeclareRet)
10591059
MAKE_CASE(NVPTXISD::DeclareRetParam)
1060-
MAKE_CASE(NVPTXISD::PrintCall)
1061-
MAKE_CASE(NVPTXISD::PrintConvergentCall)
1062-
MAKE_CASE(NVPTXISD::PrintCallUni)
1063-
MAKE_CASE(NVPTXISD::PrintConvergentCallUni)
1060+
MAKE_CASE(NVPTXISD::CALL)
10641061
MAKE_CASE(NVPTXISD::LoadParam)
10651062
MAKE_CASE(NVPTXISD::LoadParamV2)
10661063
MAKE_CASE(NVPTXISD::LoadParamV4)
10671064
MAKE_CASE(NVPTXISD::StoreParam)
10681065
MAKE_CASE(NVPTXISD::StoreParamV2)
10691066
MAKE_CASE(NVPTXISD::StoreParamV4)
1070-
MAKE_CASE(NVPTXISD::CallArgBegin)
1071-
MAKE_CASE(NVPTXISD::CallArg)
1072-
MAKE_CASE(NVPTXISD::LastCallArg)
1073-
MAKE_CASE(NVPTXISD::CallArgEnd)
1074-
MAKE_CASE(NVPTXISD::CallVoid)
1075-
MAKE_CASE(NVPTXISD::Prototype)
10761067
MAKE_CASE(NVPTXISD::MoveParam)
10771068
MAKE_CASE(NVPTXISD::StoreRetval)
10781069
MAKE_CASE(NVPTXISD::StoreRetvalV2)
10791070
MAKE_CASE(NVPTXISD::StoreRetvalV4)
10801071
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
10811072
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
1082-
MAKE_CASE(NVPTXISD::CallSeqBegin)
1083-
MAKE_CASE(NVPTXISD::CallSeqEnd)
10841073
MAKE_CASE(NVPTXISD::CallPrototype)
10851074
MAKE_CASE(NVPTXISD::ProxyReg)
10861075
MAKE_CASE(NVPTXISD::LoadV2)
@@ -1578,9 +1567,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15781567
? promoteScalarArgumentSize(TypeSize * 8)
15791568
: TypeSize * 8;
15801569

1581-
Chain = DAG.getNode(
1582-
NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
1583-
{Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
1570+
Chain =
1571+
DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
1572+
{Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
15841573
}
15851574
InGlue = Chain.getValue(1);
15861575

@@ -1717,16 +1706,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
17171706
const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
17181707
if (!shouldPassAsArray(RetTy)) {
17191708
const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
1720-
SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(PromotedResultSize),
1721-
GetI32(0), InGlue};
17221709
Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
1723-
DeclareRetOps);
1710+
{Chain, GetI32(PromotedResultSize), InGlue});
17241711
InGlue = Chain.getValue(1);
17251712
} else {
1726-
SDValue DeclareRetOps[] = {Chain, GetI32(RetAlign->value()),
1727-
GetI32(ResultSize / 8), GetI32(0), InGlue};
1728-
Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
1729-
{MVT::Other, MVT::Glue}, DeclareRetOps);
1713+
Chain = DAG.getNode(
1714+
NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
1715+
{Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
17301716
InGlue = Chain.getValue(1);
17311717
}
17321718
}
@@ -1777,25 +1763,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
17771763
HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
17781764
UniqueCallSite);
17791765
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
1780-
SDValue ProtoOps[] = {
1781-
Chain,
1782-
DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
1783-
InGlue,
1784-
};
1785-
Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
1786-
ProtoOps);
1766+
Chain = DAG.getNode(
1767+
NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
1768+
{Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue});
17871769
InGlue = Chain.getValue(1);
17881770
}
1789-
// Op to just print "call"
1790-
SDValue PrintCallOps[] = {Chain, GetI32(Ins.empty() ? 0 : 1), InGlue};
1791-
// We model convergent calls as separate opcodes.
1792-
unsigned Opcode =
1793-
IsIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
1794-
if (CLI.IsConvergent)
1795-
Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
1796-
: NVPTXISD::PrintConvergentCall;
1797-
Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, PrintCallOps);
1798-
InGlue = Chain.getValue(1);
17991771

18001772
if (ConvertToIndirectCall) {
18011773
// Copy the function ptr to a ptx register and use the register to call the
@@ -1809,38 +1781,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
18091781
Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
18101782
}
18111783

1812-
// Ops to print out the function name
1813-
SDValue CallVoidOps[] = { Chain, Callee, InGlue };
1814-
Chain =
1815-
DAG.getNode(NVPTXISD::CallVoid, dl, {MVT::Other, MVT::Glue}, CallVoidOps);
1816-
InGlue = Chain.getValue(1);
1817-
1818-
// Ops to print out the param list
1819-
SDValue CallArgBeginOps[] = { Chain, InGlue };
1820-
Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, {MVT::Other, MVT::Glue},
1821-
CallArgBeginOps);
1784+
const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1785+
const unsigned NumArgs =
1786+
std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
1787+
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1788+
/// NumParams, Callee, Proto, InGlue)
1789+
Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue},
1790+
{Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1791+
GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
1792+
GetI32(Proto), InGlue});
18221793
InGlue = Chain.getValue(1);
18231794

1824-
const unsigned E = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
1825-
for (const unsigned I : llvm::seq(E)) {
1826-
const unsigned Opcode =
1827-
I == (E - 1) ? NVPTXISD::LastCallArg : NVPTXISD::CallArg;
1828-
SDValue CallArgOps[] = {Chain, GetI32(1), GetI32(I), InGlue};
1829-
Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, CallArgOps);
1830-
InGlue = Chain.getValue(1);
1831-
}
1832-
SDValue CallArgEndOps[] = {Chain, GetI32(IsIndirectCall ? 0 : 1), InGlue};
1833-
Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, {MVT::Other, MVT::Glue},
1834-
CallArgEndOps);
1835-
InGlue = Chain.getValue(1);
1836-
1837-
if (IsIndirectCall) {
1838-
SDValue PrototypeOps[] = {Chain, GetI32(UniqueCallSite), InGlue};
1839-
Chain = DAG.getNode(NVPTXISD::Prototype, dl, {MVT::Other, MVT::Glue},
1840-
PrototypeOps);
1841-
InGlue = Chain.getValue(1);
1842-
}
1843-
18441795
SmallVector<SDValue, 16> ProxyRegOps;
18451796
// An item of the vector is filled if the element does not need a ProxyReg
18461797
// operation on it and should be added to InVals as is. ProxyRegOps and

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,14 @@ enum NodeType : unsigned {
2929
DeclareScalarParam,
3030
DeclareRetParam,
3131
DeclareRet,
32-
PrintCall,
33-
PrintConvergentCall,
34-
PrintCallUni,
35-
PrintConvergentCallUni,
36-
CallArgBegin,
37-
CallArg,
38-
LastCallArg,
39-
CallArgEnd,
40-
CallVoid,
41-
Prototype,
32+
33+
/// This node represents a PTX call instruction. Its operands are as follows:
34+
///
35+
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
36+
/// NumParams, Callee, Proto, InGlue)
37+
CALL,
38+
4239
MoveParam,
43-
CallSeqBegin,
44-
CallSeqEnd,
4540
CallPrototype,
4641
ProxyReg,
4742
FSHL_CLAMP,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -190,22 +190,4 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
190190
BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB);
191191
BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
192192
return 2;
193-
}
194-
195-
bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
196-
const MachineBasicBlock *MBB,
197-
const MachineFunction &MF) const {
198-
// Prevent the scheduler from reordering & splitting up MachineInstrs
199-
// which must stick together (in initially set order) to
200-
// comprise a valid PTX function call sequence.
201-
switch (MI.getOpcode()) {
202-
case NVPTX::CallUniPrintCallRetInst1:
203-
case NVPTX::CallArgBeginInst:
204-
case NVPTX::CallArgParam:
205-
case NVPTX::LastCallArgParam:
206-
case NVPTX::CallArgEndInst1:
207-
return true;
208-
}
209-
210-
return TargetInstrInfo::isSchedulingBoundary(MI, MBB, MF);
211-
}
193+
}

llvm/lib/Target/NVPTX/NVPTXInstrInfo.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ class NVPTXInstrInfo : public NVPTXGenInstrInfo {
6666
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
6767
const DebugLoc &DL,
6868
int *BytesAdded = nullptr) const override;
69-
bool isSchedulingBoundary(const MachineInstr &MI,
70-
const MachineBasicBlock *MBB,
71-
const MachineFunction &MF) const override;
7269
};
7370

7471
} // namespace llvm

0 commit comments

Comments
 (0)