@AlexMaclean
Member

This change consolidates and cleans up various NVPTXISD target-specific nodes in order to simplify SDAG ISel. While there are some whitespace changes in the emitted PTX, it is otherwise a non-functional change.

NVPTXISD::Wrapper - This node was used to wrap external-symbol and global-address nodes. It is redundant and has been removed. Instead, we use the non-target versions of these nodes and convert them to their target counterparts during ISel.

NVPTXISD::CALL - Most of the family of nodes used to represent a PTX call instruction has been replaced by this single new node. It corresponds to a single instruction and is therefore much simpler to create and lower.
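
As a rough illustration of how a single call instruction can print its own return and parameter lists, here is a standalone sketch modeled on the `printCallOperand` helper in the diff. It is not the LLVM code: the names `formatRetList` and `formatParamList` are hypothetical, and `std::ostringstream` stands in for `raw_ostream` and `interleaveComma`.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical stand-in for the "RetList" modifier: a call has either no
// return value or exactly one, which is always named retval0.
std::string formatRetList(int numReturns) {
  assert((numReturns == 0 || numReturns == 1) && "Invalid return list");
  return numReturns ? " (retval0)," : "";
}

// Hypothetical stand-in for the "ParamList" modifier: emits
// "param0, param1, ..." for the given number of parameters.
std::string formatParamList(int numParams) {
  assert(numParams >= 0 && "Invalid parameter list");
  std::ostringstream os;
  for (int i = 0; i < numParams; ++i) {
    if (i != 0)
      os << ", ";
    os << "param" << i;
  }
  return os.str();
}
```

With these helpers, `formatRetList(1)` yields ` (retval0),` and `formatParamList(3)` yields `param0, param1, param2`, so the whole operand list can be derived from two immediates rather than from a chain of glued nodes.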

@llvmbot
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Patch is 171.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145581.diff

43 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+22)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+8-51)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+23-98)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+7-25)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+1-19)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.h (-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+79-161)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+2)
  • (modified) llvm/test/CodeGen/NVPTX/alias.ll (+1-2)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+1-6)
  • (modified) llvm/test/CodeGen/NVPTX/byval-const-global.ll (+1-5)
  • (modified) llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll (+1-2)
  • (modified) llvm/test/CodeGen/NVPTX/combine-mad.ll (+1-6)
  • (modified) llvm/test/CodeGen/NVPTX/convergent-mir-call.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll (+6-37)
  • (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+2-10)
  • (modified) llvm/test/CodeGen/NVPTX/f16-instructions.ll (+4-20)
  • (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+3-18)
  • (modified) llvm/test/CodeGen/NVPTX/fma.ll (+2-12)
  • (modified) llvm/test/CodeGen/NVPTX/forward-ld-param.ll (+2-10)
  • (modified) llvm/test/CodeGen/NVPTX/fp128-storage-type.ll (+1-5)
  • (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+71-52)
  • (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+3-18)
  • (modified) llvm/test/CodeGen/NVPTX/indirect_byval.ll (+2-14)
  • (modified) llvm/test/CodeGen/NVPTX/ldparam-v4.ll (+1-4)
  • (modified) llvm/test/CodeGen/NVPTX/local-stack-frame.ll (+6-30)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+5-31)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+3-11)
  • (modified) llvm/test/CodeGen/NVPTX/lower-byval-args.ll (+4-20)
  • (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+1-5)
  • (modified) llvm/test/CodeGen/NVPTX/naked-fn-with-frame-pointer.ll (+4-16)
  • (modified) llvm/test/CodeGen/NVPTX/param-add.ll (+1-5)
  • (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+56-112)
  • (modified) llvm/test/CodeGen/NVPTX/param-overalign.ll (+45-44)
  • (modified) llvm/test/CodeGen/NVPTX/param-vectorize-device.ll (+12-60)
  • (modified) llvm/test/CodeGen/NVPTX/shift-opt.ll (+2-10)
  • (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+84-420)
  • (modified) llvm/test/CodeGen/NVPTX/store-undef.ll (+2-10)
  • (modified) llvm/test/CodeGen/NVPTX/tex-read-cuda.ll (+1-5)
  • (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+7-35)
  • (modified) llvm/test/CodeGen/NVPTX/unreachable.ll (+4-16)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+4-24)
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index cc79257fb9c86..28f6968ee6caf 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -457,3 +457,25 @@ void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
   }
   llvm_unreachable("Invalid cta_group in printCTAGroup");
 }
+
+void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
+                                        raw_ostream &O, StringRef Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  assert(MO.isImm() && "Invalid operand");
+  const auto Imm = MO.getImm();
+
+  if (Modifier == "RetList") {
+    assert((Imm == 1 || Imm == 0) && "Invalid return list");
+    if (Imm)
+      O << " (retval0),";
+    return;
+  }
+
+  if (Modifier == "ParamList") {
+    assert(Imm >= 0 && "Invalid parameter list");
+    interleaveComma(llvm::seq(Imm), O,
+                    [&](const auto &I) { O << "param" << I; });
+    return;
+  }
+  llvm_unreachable("Invalid modifier");
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index f73af7a3f2c6e..6189284e8a58c 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -52,6 +52,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
   void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
   void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
+  void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
+                        StringRef Modifier = {});
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ff10eea371049..af9050c55d33a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -160,8 +160,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4:
-  case NVPTXISD::StoreParamS32:
-  case NVPTXISD::StoreParamU32:
     if (tryStoreParam(N))
       return;
     break;
@@ -909,19 +907,9 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
   switch (IID) {
   default:
     return false;
-  case Intrinsic::nvvm_texsurf_handle_internal:
-    SelectTexSurfHandle(N);
-    return true;
   }
 }
 
-void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
-  // Op 0 is the intrinsic ID
-  SDValue Wrapper = N->getOperand(1);
-  SDValue GlobalVal = Wrapper.getOperand(0);
-  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
-                                        MVT::i64, GlobalVal));
-}
 
 void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
   SDValue Src = N->getOperand(0);
@@ -1717,8 +1705,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
   switch (N->getOpcode()) {
   default:
     llvm_unreachable("Unexpected opcode");
-  case NVPTXISD::StoreParamU32:
-  case NVPTXISD::StoreParamS32:
   case NVPTXISD::StoreParam:
     NumElts = 1;
     break;
@@ -1796,27 +1782,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     }
     }
     break;
-  // Special case: if we have a sign-extend/zero-extend node, insert the
-  // conversion instruction first, and use that as the value operand to
-  // the selected StoreParam node.
-  case NVPTXISD::StoreParamU32: {
-    Opcode = NVPTX::StoreParamI32_r;
-    SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
-                                                MVT::i32);
-    SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
-                                         MVT::i32, Ops[0], CvtNone);
-    Ops[0] = SDValue(Cvt, 0);
-    break;
-  }
-  case NVPTXISD::StoreParamS32: {
-    Opcode = NVPTX::StoreParamI32_r;
-    SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
-                                                MVT::i32);
-    SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
-                                         MVT::i32, Ops[0], CvtNone);
-    Ops[0] = SDValue(Cvt, 0);
-    break;
-  }
   }
 
   SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
@@ -2105,22 +2070,14 @@ static inline bool isAddLike(const SDValue V) {
 // selectBaseADDR - Match a dag node which will serve as the base address for an
 // ADDR operand pair.
 static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
-  // Return true if TGA or ES.
-  if (N.getOpcode() == ISD::TargetGlobalAddress ||
-      N.getOpcode() == ISD::TargetExternalSymbol)
-    return N;
-
-  if (N.getOpcode() == NVPTXISD::Wrapper)
-    return N.getOperand(0);
-
-  // addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
-  if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
-    if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
-        CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
-        CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
-      return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
-
-  if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
+  if (const auto *GA = dyn_cast<GlobalAddressSDNode>(N))
+    return DAG->getTargetGlobalAddress(GA->getGlobal(), SDLoc(N),
+                                       GA->getValueType(0), GA->getOffset(),
+                                       GA->getTargetFlags());
+  if (const auto *ES = dyn_cast<ExternalSymbolSDNode>(N))
+    return DAG->getTargetExternalSymbol(ES->getSymbol(), ES->getValueType(0),
+                                        ES->getTargetFlags());
+  if (const auto *FIN = dyn_cast<FrameIndexSDNode>(N))
     return DAG->getTargetFrameIndex(FIN->getIndex(), FIN->getValueType(0));
 
   return N;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 676654d6d33e7..7e45601b00ffa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -702,9 +702,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::BR_JT, MVT::Other, Custom);
   setOperationAction(ISD::BRIND, MVT::Other, Expand);
 
-  setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
-  setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
-
   // We want to legalize constant related memmove and memcopy
   // intrinsics.
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
@@ -1054,45 +1051,24 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
   case NVPTXISD::FIRST_NUMBER:
     break;
 
-    MAKE_CASE(NVPTXISD::CALL)
     MAKE_CASE(NVPTXISD::RET_GLUE)
-    MAKE_CASE(NVPTXISD::LOAD_PARAM)
-    MAKE_CASE(NVPTXISD::Wrapper)
     MAKE_CASE(NVPTXISD::DeclareParam)
     MAKE_CASE(NVPTXISD::DeclareScalarParam)
     MAKE_CASE(NVPTXISD::DeclareRet)
-    MAKE_CASE(NVPTXISD::DeclareScalarRet)
     MAKE_CASE(NVPTXISD::DeclareRetParam)
-    MAKE_CASE(NVPTXISD::PrintCall)
-    MAKE_CASE(NVPTXISD::PrintConvergentCall)
-    MAKE_CASE(NVPTXISD::PrintCallUni)
-    MAKE_CASE(NVPTXISD::PrintConvergentCallUni)
+    MAKE_CASE(NVPTXISD::CALL)
     MAKE_CASE(NVPTXISD::LoadParam)
     MAKE_CASE(NVPTXISD::LoadParamV2)
     MAKE_CASE(NVPTXISD::LoadParamV4)
     MAKE_CASE(NVPTXISD::StoreParam)
     MAKE_CASE(NVPTXISD::StoreParamV2)
     MAKE_CASE(NVPTXISD::StoreParamV4)
-    MAKE_CASE(NVPTXISD::StoreParamS32)
-    MAKE_CASE(NVPTXISD::StoreParamU32)
-    MAKE_CASE(NVPTXISD::CallArgBegin)
-    MAKE_CASE(NVPTXISD::CallArg)
-    MAKE_CASE(NVPTXISD::LastCallArg)
-    MAKE_CASE(NVPTXISD::CallArgEnd)
-    MAKE_CASE(NVPTXISD::CallVoid)
-    MAKE_CASE(NVPTXISD::CallVal)
-    MAKE_CASE(NVPTXISD::CallSymbol)
-    MAKE_CASE(NVPTXISD::Prototype)
     MAKE_CASE(NVPTXISD::MoveParam)
     MAKE_CASE(NVPTXISD::StoreRetval)
     MAKE_CASE(NVPTXISD::StoreRetvalV2)
     MAKE_CASE(NVPTXISD::StoreRetvalV4)
-    MAKE_CASE(NVPTXISD::PseudoUseParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
-    MAKE_CASE(NVPTXISD::RETURN)
-    MAKE_CASE(NVPTXISD::CallSeqBegin)
-    MAKE_CASE(NVPTXISD::CallSeqEnd)
     MAKE_CASE(NVPTXISD::CallPrototype)
     MAKE_CASE(NVPTXISD::ProxyReg)
     MAKE_CASE(NVPTXISD::LoadV2)
@@ -1114,7 +1090,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::STACKSAVE)
     MAKE_CASE(NVPTXISD::SETP_F16X2)
     MAKE_CASE(NVPTXISD::SETP_BF16X2)
-    MAKE_CASE(NVPTXISD::Dummy)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
     MAKE_CASE(NVPTXISD::BrxEnd)
@@ -1188,15 +1163,6 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
   }
 }
 
-SDValue
-NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
-  SDLoc dl(Op);
-  const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
-  auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
-  Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
-  return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
-}
-
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
@@ -1600,9 +1566,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
               ? promoteScalarArgumentSize(TypeSize * 8)
               : TypeSize * 8;
 
-      Chain = DAG.getNode(
-          NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
-          {Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
+      Chain =
+          DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+                      {Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
     }
     InGlue = Chain.getValue(1);
 
@@ -1739,16 +1705,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
     if (!shouldPassAsArray(RetTy)) {
       const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
-      SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(PromotedResultSize),
-                                 GetI32(0), InGlue};
       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
-                          DeclareRetOps);
+                          {Chain, GetI32(PromotedResultSize), InGlue});
       InGlue = Chain.getValue(1);
     } else {
-      SDValue DeclareRetOps[] = {Chain, GetI32(RetAlign->value()),
-                                 GetI32(ResultSize / 8), GetI32(0), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
-                          {MVT::Other, MVT::Glue}, DeclareRetOps);
+      Chain = DAG.getNode(
+          NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
+          {Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
       InGlue = Chain.getValue(1);
     }
   }
@@ -1799,25 +1762,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                      HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
                      UniqueCallSite);
     const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
-    SDValue ProtoOps[] = {
-        Chain,
-        DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
-        InGlue,
-    };
-    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
-                        ProtoOps);
+    Chain = DAG.getNode(
+        NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
+        {Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue});
     InGlue = Chain.getValue(1);
   }
-  // Op to just print "call"
-  SDValue PrintCallOps[] = {Chain, GetI32(Ins.empty() ? 0 : 1), InGlue};
-  // We model convergent calls as separate opcodes.
-  unsigned Opcode =
-      IsIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
-  if (CLI.IsConvergent)
-    Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
-                                              : NVPTXISD::PrintConvergentCall;
-  Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, PrintCallOps);
-  InGlue = Chain.getValue(1);
 
   if (ConvertToIndirectCall) {
     // Copy the function ptr to a ptx register and use the register to call the
@@ -1831,38 +1780,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
   }
 
-  // Ops to print out the function name
-  SDValue CallVoidOps[] = { Chain, Callee, InGlue };
-  Chain =
-      DAG.getNode(NVPTXISD::CallVoid, dl, {MVT::Other, MVT::Glue}, CallVoidOps);
-  InGlue = Chain.getValue(1);
-
-  // Ops to print out the param list
-  SDValue CallArgBeginOps[] = { Chain, InGlue };
-  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, {MVT::Other, MVT::Glue},
-                      CallArgBeginOps);
+  const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
+  const unsigned NumArgs =
+      std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
+  /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
+  ///      NumParams, Callee, Proto, InGlue)
+  Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue},
+                      {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
+                       GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
+                       GetI32(Proto), InGlue});
   InGlue = Chain.getValue(1);
 
-  const unsigned E = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
-  for (const unsigned I : llvm::seq(E)) {
-    const unsigned Opcode =
-        I == (E - 1) ? NVPTXISD::LastCallArg : NVPTXISD::CallArg;
-    SDValue CallArgOps[] = {Chain, GetI32(1), GetI32(I), InGlue};
-    Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, CallArgOps);
-    InGlue = Chain.getValue(1);
-  }
-  SDValue CallArgEndOps[] = {Chain, GetI32(IsIndirectCall ? 0 : 1), InGlue};
-  Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, {MVT::Other, MVT::Glue},
-                      CallArgEndOps);
-  InGlue = Chain.getValue(1);
-
-  if (IsIndirectCall) {
-    SDValue PrototypeOps[] = {Chain, GetI32(UniqueCallSite), InGlue};
-    Chain = DAG.getNode(NVPTXISD::Prototype, dl, {MVT::Other, MVT::Glue},
-                        PrototypeOps);
-    InGlue = Chain.getValue(1);
-  }
-
   SmallVector<SDValue, 16> ProxyRegOps;
   // An item of the vector is filled if the element does not need a ProxyReg
   // operation on it and should be added to InVals as is. ProxyRegOps and
@@ -2918,8 +2846,6 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return SDValue();
   case ISD::ADDRSPACECAST:
     return LowerADDRSPACECAST(Op, DAG);
-  case ISD::GlobalAddress:
-    return LowerGlobalAddress(Op, DAG);
   case ISD::INTRINSIC_W_CHAIN:
     return Op;
   case ISD::INTRINSIC_WO_CHAIN:
@@ -3128,8 +3054,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
   EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());
 
   // Store the address of unsized array <function>_vararg[] in the ap object.
-  SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
-  SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
+  SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
 
   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
   return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
@@ -3369,7 +3294,7 @@ SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
                                             EVT v) const {
   StringRef SavedStr = nvTM->getStrPool().save(
       getParamName(&DAG.getMachineFunction().getFunction(), idx));
-  return DAG.getTargetExternalSymbol(SavedStr.data(), v);
+  return DAG.getExternalSymbol(SavedStr.data(), v);
 }
 
 SDValue NVPTXTargetLowering::LowerFormalArguments(
@@ -3437,7 +3362,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
       SDValue P;
       if (isKernelFunction(*F)) {
-        P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol);
+        P = ArgSymbol;
         P.getNode()->setIROrder(Arg.getArgNo() + 1);
       } else {
         P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 0a54a8fd71f32..5efdd1582214a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -24,32 +24,19 @@ namespace NVPTXISD {
 enum NodeType : unsigned {
   // Start the numbering from where ISD NodeType finishes.
   FIRST_NUMBER = ISD::BUILTIN_OP_END,
-  Wrapper,
-  CALL,
   RET_GLUE,
-  LOAD_PARAM,
   DeclareParam,
   DeclareScalarParam,
   DeclareRetParam,
   DeclareRet,
-  DeclareScalarRet,
-  PrintCall,
-  PrintConvergentCall,
-  PrintCallUni,
-  PrintConvergentCallUni,
-  CallArgBegin,
-  CallArg,
-  LastCallArg,
-  CallArgEnd,
-  CallVoid,
-  CallVal,
-  CallSymbol,
-  Prototype,
+
+  /// This node represents a PTX call instruction. Its operands are as follows:
+  ///
+  /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
+  ///      NumParams, Callee, Proto, InGlue)
+  CALL,
+
   MoveParam,
-  PseudoUseParam,
-  RETURN,
-  CallSeqBegin,
-  CallSeqEnd,
   CallPrototype,
   ProxyReg,
   FSHL_CLAMP,
@@ -83,7 +70,6 @@ enum NodeType : unsigned {
   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
-  Dummy,
 
   FIRST_MEMORY_OPCODE,
   LoadV2 = FIRST_MEMORY_OPCODE,
@@ -100,8 +86,6 @@ enum NodeType : unsigned {
   StoreParam,
   StoreParamV2,
   StoreParamV4,
-  StoreParamS32, // to sext and store a <32bit value, not used currently
-  StoreParamU32, // to zext and store a <32bit value, not used currently
   StoreRetval,
   StoreRetvalV2,
   StoreRetvalV4,
@@ -120,8 +104,6 @@ class NVPTXTargetLowering : public TargetLowering {
                                const NVPTXSubtarget &STI);
   SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
 
-  SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
-
   const char *getTargetNodeName(unsigned Opcode) const override;
 
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index bf84d1dca4ed5..e218ef17bb09b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -190,22 +190,4 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
   BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB);
   BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
   return 2;
-}
-
-bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI,
-                                          const MachineBasicBlock *MBB,
-                                          const MachineFunction &MF) const {
-  // Prevent the scheduler from reordering & splitting up MachineInstrs
-  // which must stick together (in initially set order) to
-  // comprise a valid PTX function call sequence.
-  switch (MI.getOpcode()) {
-  case NVPTX::CallUniPrintCallRetInst1:
-  case NVPTX::CallArgBeginInst:
-  case NVPTX::CallArgParam:
-  case NVPTX::LastCallArgParam:
-  case NVPTX::CallArgEndInst1:
-    return true;
-  }
-
-  return TargetInstrInfo::isSchedulingBoundary(MI, MBB, MF);
-}
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
index 95464dbbd176d..4e9dc9d3b4686 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h
@@ -66,9 +66,6 @@ class NVPTXInstrInfo : public NVPTXGenInstrInfo {
                         MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
                         const DebugLoc &DL,
                         int *BytesAdded = nullptr) const override;
-  bool isSchedulingBoundary(const MachineInstr &MI,
-                            const MachineBasicBlock *MBB,
-                            const MachineFu...
[truncated]

@github-actions

github-actions bot commented Jun 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/cleanup-dead-nodes branch from f30694d to 6315c7d Compare June 24, 2025 20:35
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/cleanup-dead-nodes branch from 6315c7d to aad4cad Compare June 24, 2025 20:36
@llvmbot llvmbot added the clang label Jun 24, 2025
Member

@Artem-B Artem-B left a comment

Nice. LGTM.

Comment on lines +474 to +479
  if (Modifier == "ParamList") {
    assert(Imm >= 0 && "Invalid parameter list");
    interleaveComma(llvm::seq(Imm), O,
                    [&](const auto &I) { O << "param" << I; });
    return;
  }
Member

Do you know if PTX imposes practical limits on the line length?
Some JIT compilations may end up generating functions with a very long list of arguments.
We may want to try folding them, or print them one per line if the total number of arguments is above a certain threshold.
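
The folding idea could look roughly like the sketch below. It is only an illustration: `formatParamListFolded`, `threshold`, and `perLine` are hypothetical names, and it assumes ptxas accepts line breaks inside a call's parameter list, which this thread does not confirm.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical helper: prints "param0, param1, ..." on one line when the
// list has at most `threshold` entries; otherwise starts a new, indented
// line after every `perLine` entries.
std::string formatParamListFolded(int numParams, int threshold, int perLine) {
  assert(numParams >= 0 && threshold >= 0 && perLine > 0);
  const bool fold = numParams > threshold;
  std::ostringstream os;
  for (int i = 0; i < numParams; ++i) {
    if (i != 0) {
      os << ',';
      // Break the line at multiples of perLine, but only when folding.
      os << ((fold && i % perLine == 0) ? "\n  " : " ");
    }
    os << "param" << i;
  }
  return os.str();
}
```

Short lists are printed unchanged, so a threshold like this would leave typical hand-read PTX as it is and only affect very long argument lists.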

Member Author

I don't know of a practical limit. I know we've had some very long lines with initialization of static arrays and with debug info in the past and as far as I know it hasn't been a problem.

Should I prophylactically add some newlines or do you think it makes sense to wait and see?

Member

Let's keep it as is. JIT-produced PTX is rarely read by anyone other than ptxas. For human use, the current form is fine.

@AlexMaclean AlexMaclean merged commit 70333de into llvm:main Jun 25, 2025
7 checks passed
// Special case: if we have a sign-extend/zero-extend node, insert the
// conversion instruction first, and use that as the value operand to
// the selected StoreParam node.
case NVPTXISD::StoreParamU32: {
Collaborator

@AlexMaclean This has removed everything but the default case (which causes MSVC build warnings) - remove the outer switch(n->getOpcode())?
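
A reduced sketch of the shape being described, under the assumption that the diagnostic in question is MSVC warning C4065 (a `switch` that contains `default` but no `case` labels); the thread itself does not name the warning number. The enum and both functions are hypothetical and only mirror the structure, not the actual `tryStoreParam` code.

```cpp
#include <cassert>

// Hypothetical opcodes, standing in for the remaining StoreParam variants.
enum Opcode { StoreParam, StoreParamV2, StoreParamV4 };

// Before: every `case` label has been deleted, so only `default` is left.
// The switch no longer selects anything, and MSVC warns about it.
int selectWithSwitch(Opcode op, int numElts) {
  int result = 0;
  switch (op) {
  default:
    result = numElts * 10;
    break;
  }
  return result;
}

// After: the body of the former `default` label runs unconditionally, so
// the outer switch can simply be dropped.
int selectWithoutSwitch(Opcode /*op*/, int numElts) {
  return numElts * 10;
}
```

Both functions compute the same result for every opcode, which is why removing the outer switch is a behavior-preserving cleanup.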

RKSimon added a commit to RKSimon/llvm-project that referenced this pull request Jun 26, 2025
llvm#145581 remove all the remaining special cases from the switch statement leaving just the default, which MSVC complains about.
RKSimon added a commit that referenced this pull request Jun 26, 2025
…145948)

#145581 removed all the remaining special cases from the switch
statement leaving just the default, which MSVC complains about.
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…5581)
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…lvm#145948)

llvm#145581 removed all the remaining special cases from the switch
statement leaving just the default, which MSVC complains about.
@ThomasRaoux
Contributor

@AlexMaclean, this PR doesn't seem to be NFC. This generates different PTX than before. In particular, I see extra mov.b32 %r, global_smem; instructions that seem to generate different SASS and cause regressions in some Triton workloads.
Before this PR I would only have one of those moves; now I see multiple. I haven't debugged why yet.

Is this something you have noticed?

@ThomasRaoux
Contributor

Reviving the LowerGlobalAddress piece makes the PTX match and avoids emitting multiple movs when the base address is used in different blocks.
I'll send a patch reintroducing it, but let me know if you think there is something better to do.

@AlexMaclean
Member Author

> @AlexMaclean, this PR doesn't seem to be NFC. This generates different PTX than before. In particular, I see extra mov.b32 %r, global_smem; instructions that seem to generate different SASS and cause regressions in some Triton workloads. Before this PR I would only have one of those moves; now I see multiple. I haven't debugged why yet.
>
> Is this something you have noticed?

Sorry about this, I haven't seen this issue before. Would you be able to attach an LLVM IR reproducer? I'd be happy to investigate and see what is going on.

@ThomasRaoux
Contributor

>> @AlexMaclean, this PR doesn't seem to be NFC. This generates different PTX than before. In particular, I see extra mov.b32 %r, global_smem; instructions that seem to generate different SASS and cause regressions in some Triton workloads. Before this PR I would only have one of those moves; now I see multiple. I haven't debugged why yet.
>> Is this something you have noticed?
>
> Sorry about this, I haven't seen this issue before. Would you be able to attach an LLVM IR reproducer? I'd be happy to investigate and see what is going on.

here is a reproducer: https://gist.github.com/ThomasRaoux/e5dbdb35a418d246ac4393a802902621
sorry it is a bit big but if you compile before and after your patch you can see there used to be one mov for global_smem and now there are two including one in the loop.

I would have thought it would be the same for ptxas but looks like it is not the case :(

I have a partial revert that solves the problem, once I have a test I'll send a PR

@ThomasRaoux
Contributor

actually, doing a partial revert may be harder than I thought. @AlexMaclean, would you be okay if I revert this PR while you figure out why this is not NFC? That would unblock our LLVM upgrade on Triton side

@ThomasRaoux
Contributor

never mind, I don't think reverting is an option as it cannot be done cleanly

@AlexMaclean
Member Author

I'm looking at your reproducer now, I'll let you know as soon as I have a fix.

@AlexMaclean
Member Author

The problem seems to be that we're now reusing the MOV_B64_i instruction to move the address of the global into a register. This instruction is marked as isAsCheapAsAMove = true so we no longer bother to do CSE on it. This doesn't necessarily seem like a problem or incorrect so I'm hesitant to "fix" it by re-introducing a non-cheap mov instruction for global-addresses. We've perturbed PTX a little bit and that can sometimes cause both regressions and improvements.
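
To make the effect concrete, here is a toy model of a CSE pass that declines to merge "cheap" instructions across blocks. It is a sketch under stated assumptions and not LLVM's MachineCSE: the `Instr` struct, the `toyCSE` function, and the exact cross-block rule are all hypothetical, and the model assumes every earlier instruction dominates every later one.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, heavily simplified machine instruction.
struct Instr {
  std::string opcode;
  std::string operand;
  int block;        // id of the basic block containing the instruction
  bool cheapAsMove; // models the isAsCheapAsAMove flag
};

// Toy CSE: drop an instruction when an identical one was already kept,
// except when it is cheap and the kept copy lives in a different block.
std::vector<Instr> toyCSE(const std::vector<Instr> &in) {
  std::vector<Instr> out;
  for (const Instr &I : in) {
    bool redundant = false;
    for (const Instr &K : out) {
      const bool same = K.opcode == I.opcode && K.operand == I.operand;
      const bool crossBlock = K.block != I.block;
      if (same && !(I.cheapAsMove && crossBlock)) {
        redundant = true;
        break;
      }
    }
    if (!redundant)
      out.push_back(I);
  }
  return out;
}
```

In this model, two `mov` instructions of `global_smem` in different blocks both survive when flagged as cheap, while the same pair collapses to one when the flag is clear, which matches the "one mov before, several after" observation reported above.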

@ThomasRaoux have you experimented with using maxnreg or --maxrregcount to help PTXAS out here? If this kernel doesn't have a register target, this might be the sort of thing that could change the compiler's guess about what it should be.

@ThomasRaoux
Contributor

ThomasRaoux commented Aug 14, 2025

> The problem seems to be that we're now reusing the MOV_B64_i instruction to move the address of the global into a register. This instruction is marked as isAsCheapAsAMove = true so we no longer bother to do CSE on it. This doesn't necessarily seem like a problem or incorrect so I'm hesitant to "fix" it by re-introducing a non-cheap mov instruction for global-addresses. We've perturbed PTX a little bit and that can sometimes cause both regressions and improvements.
>
> @ThomasRaoux have you experimented with using maxnreg or --maxrregcount to help PTXAS out here? If this kernel doesn't have a register target, this might be the sort of thing that could change the compiler's guess about what it should be.

Looking at the SASS, it doesn't use extra registers. I see extra arithmetic in the loop. I need to take an ncu trace to understand why it makes a significant difference, but it might just be extra arithmetic and worse scheduling.
If ptxas doesn't treat this move as a no-op, CSEing it would be nice. I'll check if I can find a workaround; otherwise I'm not sure how to unblock this, as the performance drop will be blocking our LLVM upgrade.

I'll let you know if ncu shows anything

@ThomasRaoux
Contributor

ThomasRaoux commented Aug 15, 2025

@AlexMaclean I compared the runs in ncu: there are no differences in occupancy and the arithmetic usage is roughly the same, but I see some large stalls on an SR_CgaCtaId read in a loop, which comes from the extra global_smem copy.

This seems to be the main reason for the significant slowdown here. It looks like a legitimate problem in what ptxas generates, and I don't think it can be worked around from the user's point of view.
Can we go back to doing CSE for the global_smem move, as this seems to help code quality?

@AlexMaclean
Copy link
Member Author

@ThomasRaoux thanks for the investigation. Please confirm if #153730 fixes the issue!

@ThomasRaoux
Copy link
Contributor

> @ThomasRaoux thanks for the investigation. Please confirm if #153730 fixes the issue!

it does! Thank you so much

Labels

backend:NVPTX clang Clang issues not falling into any other category