-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[NVPTX] Consolidate and cleanup various NVPTXISD nodes (NFC) #145581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Consolidate and cleanup various NVPTXISD nodes (NFC) #145581
Conversation
|
@llvm/pr-subscribers-clang @llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesThis change consolidates and cleans up various NVPTXISD target-specific nodes in order to simplify SDAG ISel. While there are some whitespace changes in the emitted PTX it is otherwise a non-functional change. NVPTXISD::Wrapper - This node was used to wrap external-symbol and global-address nodes. It is redundant and has been removed. Instead we use the non-target versions of these nodes and convert them appropriately during ISel. NVPTXISD::CALL - Much of the family of nodes used to represent a PTX call instruction have been replaced by this new single node. It corresponds to a single instruction and is therefore much simpler to create and lower. Patch is 171.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145581.diff 43 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index cc79257fb9c86..28f6968ee6caf 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -457,3 +457,25 @@ void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
}
llvm_unreachable("Invalid cta_group in printCTAGroup");
}
+
+void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
+ raw_ostream &O, StringRef Modifier) {
+ const MCOperand &MO = MI->getOperand(OpNum);
+ assert(MO.isImm() && "Invalid operand");
+ const auto Imm = MO.getImm();
+
+ if (Modifier == "RetList") {
+ assert((Imm == 1 || Imm == 0) && "Invalid return list");
+ if (Imm)
+ O << " (retval0),";
+ return;
+ }
+
+ if (Modifier == "ParamList") {
+ assert(Imm >= 0 && "Invalid parameter list");
+ interleaveComma(llvm::seq(Imm), O,
+ [&](const auto &I) { O << "param" << I; });
+ return;
+ }
+ llvm_unreachable("Invalid modifier");
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index f73af7a3f2c6e..6189284e8a58c 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -52,6 +52,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
+ void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
+ StringRef Modifier = {});
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ff10eea371049..af9050c55d33a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -160,8 +160,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
- case NVPTXISD::StoreParamS32:
- case NVPTXISD::StoreParamU32:
if (tryStoreParam(N))
return;
break;
@@ -909,19 +907,9 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
switch (IID) {
default:
return false;
- case Intrinsic::nvvm_texsurf_handle_internal:
- SelectTexSurfHandle(N);
- return true;
}
}
-void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
- // Op 0 is the intrinsic ID
- SDValue Wrapper = N->getOperand(1);
- SDValue GlobalVal = Wrapper.getOperand(0);
- ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
- MVT::i64, GlobalVal));
-}
void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
SDValue Src = N->getOperand(0);
@@ -1717,8 +1705,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
switch (N->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
- case NVPTXISD::StoreParamU32:
- case NVPTXISD::StoreParamS32:
case NVPTXISD::StoreParam:
NumElts = 1;
break;
@@ -1796,27 +1782,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
}
}
break;
- // Special case: if we have a sign-extend/zero-extend node, insert the
- // conversion instruction first, and use that as the value operand to
- // the selected StoreParam node.
- case NVPTXISD::StoreParamU32: {
- Opcode = NVPTX::StoreParamI32_r;
- SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
- MVT::i32);
- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
- MVT::i32, Ops[0], CvtNone);
- Ops[0] = SDValue(Cvt, 0);
- break;
- }
- case NVPTXISD::StoreParamS32: {
- Opcode = NVPTX::StoreParamI32_r;
- SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
- MVT::i32);
- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
- MVT::i32, Ops[0], CvtNone);
- Ops[0] = SDValue(Cvt, 0);
- break;
- }
}
SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
@@ -2105,22 +2070,14 @@ static inline bool isAddLike(const SDValue V) {
// selectBaseADDR - Match a dag node which will serve as the base address for an
// ADDR operand pair.
static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
- // Return true if TGA or ES.
- if (N.getOpcode() == ISD::TargetGlobalAddress ||
- N.getOpcode() == ISD::TargetExternalSymbol)
- return N;
-
- if (N.getOpcode() == NVPTXISD::Wrapper)
- return N.getOperand(0);
-
- // addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
- if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
- if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
- CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
- CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
- return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
-
- if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
+ if (const auto *GA = dyn_cast<GlobalAddressSDNode>(N))
+ return DAG->getTargetGlobalAddress(GA->getGlobal(), SDLoc(N),
+ GA->getValueType(0), GA->getOffset(),
+ GA->getTargetFlags());
+ if (const auto *ES = dyn_cast<ExternalSymbolSDNode>(N))
+ return DAG->getTargetExternalSymbol(ES->getSymbol(), ES->getValueType(0),
+ ES->getTargetFlags());
+ if (const auto *FIN = dyn_cast<FrameIndexSDNode>(N))
return DAG->getTargetFrameIndex(FIN->getIndex(), FIN->getValueType(0));
return N;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 676654d6d33e7..7e45601b00ffa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -702,9 +702,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);
- setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
- setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
-
// We want to legalize constant related memmove and memcopy
// intrinsics.
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
@@ -1054,45 +1051,24 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
case NVPTXISD::FIRST_NUMBER:
break;
- MAKE_CASE(NVPTXISD::CALL)
MAKE_CASE(NVPTXISD::RET_GLUE)
- MAKE_CASE(NVPTXISD::LOAD_PARAM)
- MAKE_CASE(NVPTXISD::Wrapper)
MAKE_CASE(NVPTXISD::DeclareParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
MAKE_CASE(NVPTXISD::DeclareRet)
- MAKE_CASE(NVPTXISD::DeclareScalarRet)
MAKE_CASE(NVPTXISD::DeclareRetParam)
- MAKE_CASE(NVPTXISD::PrintCall)
- MAKE_CASE(NVPTXISD::PrintConvergentCall)
- MAKE_CASE(NVPTXISD::PrintCallUni)
- MAKE_CASE(NVPTXISD::PrintConvergentCallUni)
+ MAKE_CASE(NVPTXISD::CALL)
MAKE_CASE(NVPTXISD::LoadParam)
MAKE_CASE(NVPTXISD::LoadParamV2)
MAKE_CASE(NVPTXISD::LoadParamV4)
MAKE_CASE(NVPTXISD::StoreParam)
MAKE_CASE(NVPTXISD::StoreParamV2)
MAKE_CASE(NVPTXISD::StoreParamV4)
- MAKE_CASE(NVPTXISD::StoreParamS32)
- MAKE_CASE(NVPTXISD::StoreParamU32)
- MAKE_CASE(NVPTXISD::CallArgBegin)
- MAKE_CASE(NVPTXISD::CallArg)
- MAKE_CASE(NVPTXISD::LastCallArg)
- MAKE_CASE(NVPTXISD::CallArgEnd)
- MAKE_CASE(NVPTXISD::CallVoid)
- MAKE_CASE(NVPTXISD::CallVal)
- MAKE_CASE(NVPTXISD::CallSymbol)
- MAKE_CASE(NVPTXISD::Prototype)
MAKE_CASE(NVPTXISD::MoveParam)
MAKE_CASE(NVPTXISD::StoreRetval)
MAKE_CASE(NVPTXISD::StoreRetvalV2)
MAKE_CASE(NVPTXISD::StoreRetvalV4)
- MAKE_CASE(NVPTXISD::PseudoUseParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
- MAKE_CASE(NVPTXISD::RETURN)
- MAKE_CASE(NVPTXISD::CallSeqBegin)
- MAKE_CASE(NVPTXISD::CallSeqEnd)
MAKE_CASE(NVPTXISD::CallPrototype)
MAKE_CASE(NVPTXISD::ProxyReg)
MAKE_CASE(NVPTXISD::LoadV2)
@@ -1114,7 +1090,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::STACKSAVE)
MAKE_CASE(NVPTXISD::SETP_F16X2)
MAKE_CASE(NVPTXISD::SETP_BF16X2)
- MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
MAKE_CASE(NVPTXISD::BrxEnd)
@@ -1188,15 +1163,6 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
}
}
-SDValue
-NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
- SDLoc dl(Op);
- const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
- auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
- Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
- return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
-}
-
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
@@ -1600,9 +1566,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
? promoteScalarArgumentSize(TypeSize * 8)
: TypeSize * 8;
- Chain = DAG.getNode(
- NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
+ Chain =
+ DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
}
InGlue = Chain.getValue(1);
@@ -1739,16 +1705,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
if (!shouldPassAsArray(RetTy)) {
const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
- SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(PromotedResultSize),
- GetI32(0), InGlue};
Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
- DeclareRetOps);
+ {Chain, GetI32(PromotedResultSize), InGlue});
InGlue = Chain.getValue(1);
} else {
- SDValue DeclareRetOps[] = {Chain, GetI32(RetAlign->value()),
- GetI32(ResultSize / 8), GetI32(0), InGlue};
- Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
- {MVT::Other, MVT::Glue}, DeclareRetOps);
+ Chain = DAG.getNode(
+ NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
+ {Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
InGlue = Chain.getValue(1);
}
}
@@ -1799,25 +1762,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
- SDValue ProtoOps[] = {
- Chain,
- DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
- InGlue,
- };
- Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
- ProtoOps);
+ Chain = DAG.getNode(
+ NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
+ {Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue});
InGlue = Chain.getValue(1);
}
- // Op to just print "call"
- SDValue PrintCallOps[] = {Chain, GetI32(Ins.empty() ? 0 : 1), InGlue};
- // We model convergent calls as separate opcodes.
- unsigned Opcode =
- IsIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
- if (CLI.IsConvergent)
- Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
- : NVPTXISD::PrintConvergentCall;
- Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, PrintCallOps);
- InGlue = Chain.getValue(1);
if (ConvertToIndirectCall) {
// Copy the function ptr to a ptx register and use the register to call the
@@ -1831,38 +1780,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
}
- // Ops to print out the function name
- SDValue CallVoidOps[] = { Chain, Callee, InGlue };
- Chain =
- DAG.getNode(NVPTXISD::CallVoid, dl, {MVT::Other, MVT::Glue}, CallVoidOps);
- InGlue = Chain.getValue(1);
-
- // Ops to print out the param list
- SDValue CallArgBeginOps[] = { Chain, InGlue };
- Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, {MVT::Other, MVT::Glue},
- CallArgBeginOps);
+ const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
+ const unsigned NumArgs =
+ std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
+ /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
+ /// NumParams, Callee, Proto, InGlue)
+ Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue},
+ {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
+ GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
+ GetI32(Proto), InGlue});
InGlue = Chain.getValue(1);
- const unsigned E = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
- for (const unsigned I : llvm::seq(E)) {
- const unsigned Opcode =
- I == (E - 1) ? NVPTXISD::LastCallArg : NVPTXISD::CallArg;
- SDValue CallArgOps[] = {Chain, GetI32(1), GetI32(I), InGlue};
- Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, CallArgOps);
- InGlue = Chain.getValue(1);
- }
- SDValue CallArgEndOps[] = {Chain, GetI32(IsIndirectCall ? 0 : 1), InGlue};
- Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, {MVT::Other, MVT::Glue},
- CallArgEndOps);
- InGlue = Chain.getValue(1);
-
- if (IsIndirectCall) {
- SDValue PrototypeOps[] = {Chain, GetI32(UniqueCallSite), InGlue};
- Chain = DAG.getNode(NVPTXISD::Prototype, dl, {MVT::Other, MVT::Glue},
- PrototypeOps);
- InGlue = Chain.getValue(1);
- }
-
SmallVector<SDValue, 16> ProxyRegOps;
// An item of the vector is filled if the element does not need a ProxyReg
// operation on it and should be added to InVals as is. ProxyRegOps and
@@ -2918,8 +2846,6 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return SDValue();
case ISD::ADDRSPACECAST:
return LowerADDRSPACECAST(Op, DAG);
- case ISD::GlobalAddress:
- return LowerGlobalAddress(Op, DAG);
case ISD::INTRINSIC_W_CHAIN:
return Op;
case ISD::INTRINSIC_WO_CHAIN:
@@ -3128,8 +3054,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
// Store the address of unsized array <function>_vararg[] in the ap object.
- SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
- SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
+ SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
@@ -3369,7 +3294,7 @@ SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
EVT v) const {
StringRef SavedStr = nvTM->getStrPool().save(
getParamName(&DAG.getMachineFunction().getFunction(), idx));
- return DAG.getTargetExternalSymbol(SavedStr.data(), v);
+ return DAG.getExternalSymbol(SavedStr.data(), v);
}
SDValue NVPTXTargetLowering::LowerFormalArguments(
@@ -3437,7 +3362,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
SDValue P;
if (isKernelFunction(*F)) {
- P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol);
+ P = ArgSymbol;
P.getNode()->setIROrder(Arg.getArgNo() + 1);
} else {
P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 0a54a8fd71f32..5efdd1582214a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -24,32 +24,19 @@ namespace NVPTXISD {
enum NodeType : unsigned {
// Start the numbering from where ISD NodeType finishes.
FIRST_NUMBER = ISD::BUILTIN_OP_END,
- Wrapper,
- CALL,
RET_GLUE,
- LOAD_PARAM,
DeclareParam,
DeclareScalarParam,
DeclareRetParam,
DeclareRet,
- DeclareScalarRet,
- PrintCall,
- PrintConvergentCall,
- PrintCallUni,
- PrintConvergentCallUni,
- CallArgBegin,
- CallArg,
- LastCallArg,
- CallArgEnd,
- CallVoid,
- CallVal,
- CallSymbol,
- Prototype,
+
+ /// This node represents a PTX call instruction. It's operands are as follows:
+ ///
+ /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
+ /// NumParams, Callee, Proto, InGlue)
+ CALL,
+
MoveParam,
- PseudoUseParam,
- RETURN,
- CallSeqBegin,
- CallSeqEnd,
CallPrototype,
ProxyReg,
FSHL_CLAMP,
@@ -83,7 +70,6 @@ enum NodeType : unsigned {
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
- Dummy,
FIRST_MEMORY_OPCODE,
LoadV2 = FIRST_MEMORY_OPCODE,
@@ -100,8 +86,6 @@ enum NodeType : unsigned {
StoreParam,
StoreParamV2,
StoreParamV4,
- StoreParamS32, // to sext and store a <32bit value, not used currently
- StoreParamU32, // to zext and store a <32bit value, not used currently
StoreRetval,
StoreRetvalV2,
StoreRetvalV4,
@@ -120,8 +104,6 @@ class NVPTXTargetLowering : public TargetLowering {
const NVPTXSubtarget &STI);
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
- SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
-
const char *getTargetNodeName(unsigned Opcode) const override;
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index bf84d1dca4ed5..e218ef17bb09b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -190,22 +190,4 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB);
BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
return 2;
-}
-
-bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
- const MachineBasicBlock *MBB,
- const MachineFunction &MF) const {
- // Prevent the scheduler from reordering & splitting up MachineInstrs
- // which must stick together (in initially set order) to
- // comprise a valid PTX function call sequence.
- switch (MI.getOpcode()) {
- case NVPTX::CallUniPrintCallRetInst1:
- case NVPTX::CallArgBeginInst:
- case NVPTX::CallArgParam:
- case NVPTX::LastCallArgParam:
- case NVPTX::CallArgEndInst1:
- return true;
- }
-
- return TargetInstrInfo::isSchedulingBoundary(MI, MBB, MF);
-}
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
index 95464dbbd176d..4e9dc9d3b4686 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
@@ -66,9 +66,6 @@ class NVPTXInstrInfo : public NVPTXGenInstrInfo {
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
const DebugLoc &DL,
int *BytesAdded = nullptr) const override;
- bool isSchedulingBoundary(const MachineInstr &MI,
- const MachineBasicBlock *MBB,
- const MachineFu...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
f30694d to
6315c7d
Compare
6315c7d to
aad4cad
Compare
Artem-B
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. LGTM.
| if (Modifier == "ParamList") { | ||
| assert(Imm >= 0 && "Invalid parameter list"); | ||
| interleaveComma(llvm::seq(Imm), O, | ||
| [&](const auto &I) { O << "param" << I; }); | ||
| return; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know if PTX imposes practical limits on the line length?
Some JIT compilations may end up generating functions with a very long list of arguments.
We may want to try folding them, or print them one per line if the total number of arguments is above certain threshold.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know of a practical limit. I know we've had some very long lines with initialization of static arrays and with debug info in the past and as far as I know it hasn't been a problem.
Should I prophylactically add some newlines or do you think it makes sense to wait and see?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep it as is. JIT-produced PTX is rarely read by anyone other than ptxas. For human use current form is fine.
| // Special case: if we have a sign-extend/zero-extend node, insert the | ||
| // conversion instruction first, and use that as the value operand to | ||
| // the selected StoreParam node. | ||
| case NVPTXISD::StoreParamU32: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AlexMaclean This has removed everything but the default case (which causes MSVC build warnings) - remove the outer switch(n->getOpcode())?
llvm#145581 remove all the remaining special cases from the switch statement leaving just the default, which MSVC complains about.
…5581) This change consolidates and cleans up various NVPTXISD target-specific nodes in order to simplify SDAG ISel. While there are some whitespace changes in the emitted PTX it is otherwise a non-functional change. NVPTXISD::Wrapper - This node was used to wrap external-symbol and global-address nodes. It is redundant and has been removed. Instead we use the non-target versions of these nodes and convert them appropriately during ISel. NVPTXISD::CALL - Much of the family of nodes used to represent a PTX call instruction have been replaced by this new single node. It corresponds to a single instruction and is therefore much simpler to create and lower.
…lvm#145948) llvm#145581 removed all the remaining special cases from the switch statement leaving just the default, which MSVC complains about.
|
@AlexMaclean, this PR doesn't seem to be NFC. This generates different PTX than before. In particular I see extra Is this is something you have noticed? |
|
Reviving the |
Sorry about this, I haven't seen this issue before. Would you be able to attach a LLVM IR reproducer? I'd be happy to investigate and see what is going on. |
here is a reproducer: https://gist.github.com/ThomasRaoux/e5dbdb35a418d246ac4393a802902621 I would have thought it would be the same for ptxas but looks like it is not the case :( I have a partial revert that solves the problem, once I have a test I'll send a PR |
|
actually, doing a partial revert may be harder than I thought. @AlexMaclean, would you be okay if I revert this PR while you figure out why this is not NFC? That would unblock our LLVM upgrade on Triton side |
|
never mind, I don't think reverting is an option as it cannot be done cleanly |
|
I'm looking at your reproducer now, I'll let you know as soon as I have a fix. |
|
The problem seems to be that we're now reusing the @ThomasRaoux have you experimented with using maxnreg or --maxrregcount to help PTXAS out here? If this kernel doesn't have a register target, this might be the sort of thing that could change the compiler's guess about what it should be. |
Looking at the sass it doesn't use extra registers. I see extra arithmetic in the loop. I need to take a ncu trace to understand why it makes a significant difference but it might just be extra arithmetic and worse scheduling. I'll let you know if ncu shows anything |
|
@AlexMaclean I compared the runs in ncu and there are no differences in occupancy and the arithmetic usage is roughly the same but I see some large stalls on This seem to be the main reason for the significant slow down here. It seems like a legit problem from what ptxas generates and I don't think it can be workaround from user point of view.
|
|
@ThomasRaoux thanks for the investigation. Please confirm if #153730 fixes the issue! |
it does! Thank you so much |



This change consolidates and cleans up various NVPTXISD target-specific nodes in order to simplify SDAG ISel. While there are some whitespace changes in the emitted PTX it is otherwise a non-functional change.
NVPTXISD::Wrapper - This node was used to wrap external-symbol and global-address nodes. It is redundant and has been removed. Instead we use the non-target versions of these nodes and convert them appropriately during ISel.
NVPTXISD::CALL - Much of the family of nodes used to represent a PTX call instruction have been replaced by this new single node. It corresponds to a single instruction and is therefore much simpler to create and lower.