Skip to content

Conversation

@AlexMaclean
Copy link
Member

This change fixes v2i8 lowering for parameters and returned values. As part of this work, I move the lowering for return values to use generic ISD::STORE nodes as these are more flexible and have existing legalization handling.

Note that calling a function with v2i8 arguments or returns is still not working but this is left for a subsequent change as this MR is already fairly large.

Partially addresses #128853

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

This change fixes v2i8 lowering for parameters and returned values. As part of this work, I move the lowering for return values to use generic ISD::STORE nodes as these are more flexible and have existing legalization handling.

Note that calling a function with v2i8 arguments or returns is still not working but this is left for a subsequent change as this MR is already fairly large.

Partially addresses #128853


Patch is 158.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145585.diff

37 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (-84)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+52-112)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+1-4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (-45)
  • (modified) llvm/test/CodeGen/NVPTX/and-or-setcc.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/atomics.ll (+2-3)
  • (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+64-72)
  • (modified) llvm/test/CodeGen/NVPTX/compute-ptx-value-vts.ll (+22-21)
  • (modified) llvm/test/CodeGen/NVPTX/convert-fp-i8.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/convert-int-sm20.ll (+49-24)
  • (modified) llvm/test/CodeGen/NVPTX/elect.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/f16-instructions.ll (+14-14)
  • (modified) llvm/test/CodeGen/NVPTX/fexp2.ll (+6-8)
  • (modified) llvm/test/CodeGen/NVPTX/flog2.ll (+6-8)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-contract.ll (+7-8)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll (+6-8)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll (+13-16)
  • (modified) llvm/test/CodeGen/NVPTX/fma.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+7-9)
  • (modified) llvm/test/CodeGen/NVPTX/idioms.ll (+100-31)
  • (modified) llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/ldg-invariant.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/ldu-i8.ll (+13-4)
  • (modified) llvm/test/CodeGen/NVPTX/ldu-ldg.ll (+2-3)
  • (modified) llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll (+3-3)
  • (modified) llvm/test/CodeGen/NVPTX/param-add.ll (+11-20)
  • (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+13-11)
  • (modified) llvm/test/CodeGen/NVPTX/param-vectorize-device.ll (+19-19)
  • (modified) llvm/test/CodeGen/NVPTX/proxy-reg-erasure-ptx.ll (+6-10)
  • (modified) llvm/test/CodeGen/NVPTX/shift-opt.ll (+6-5)
  • (modified) llvm/test/CodeGen/NVPTX/tid-range.ll (+12-5)
  • (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+488-337)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+3-6)
  • (modified) llvm/test/CodeGen/NVPTX/vector-returns.ll (+88-70)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ff10eea371049..8f26d235279b8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -151,12 +151,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     if (tryLoadParam(N))
       return;
     break;
-  case NVPTXISD::StoreRetval:
-  case NVPTXISD::StoreRetvalV2:
-  case NVPTXISD::StoreRetvalV4:
-    if (tryStoreRetval(N))
-      return;
-    break;
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4:
@@ -1530,84 +1524,6 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
   return true;
 }
 
-bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
-  SDLoc DL(N);
-  SDValue Chain = N->getOperand(0);
-  SDValue Offset = N->getOperand(1);
-  unsigned OffsetVal = Offset->getAsZExtVal();
-  MemSDNode *Mem = cast<MemSDNode>(N);
-
-  // How many elements do we have?
-  unsigned NumElts = 1;
-  switch (N->getOpcode()) {
-  default:
-    return false;
-  case NVPTXISD::StoreRetval:
-    NumElts = 1;
-    break;
-  case NVPTXISD::StoreRetvalV2:
-    NumElts = 2;
-    break;
-  case NVPTXISD::StoreRetvalV4:
-    NumElts = 4;
-    break;
-  }
-
-  // Build vector of operands
-  SmallVector<SDValue, 6> Ops;
-  for (unsigned i = 0; i < NumElts; ++i)
-    Ops.push_back(N->getOperand(i + 2));
-  Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain});
-
-  // Determine target opcode
-  // If we have an i1, use an 8-bit store. The lowering code in
-  // NVPTXISelLowering will have already emitted an upcast.
-  std::optional<unsigned> Opcode = 0;
-  switch (NumElts) {
-  default:
-    return false;
-  case 1:
-    Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
-                             NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
-                             NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64);
-    if (Opcode == NVPTX::StoreRetvalI8) {
-      // Fine tune the opcode depending on the size of the operand.
-      // This helps to avoid creating redundant COPY instructions in
-      // InstrEmitter::AddRegisterOperand().
-      switch (Ops[0].getSimpleValueType().SimpleTy) {
-      default:
-        break;
-      case MVT::i32:
-        Opcode = NVPTX::StoreRetvalI8TruncI32;
-        break;
-      case MVT::i64:
-        Opcode = NVPTX::StoreRetvalI8TruncI64;
-        break;
-      }
-    }
-    break;
-  case 2:
-    Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
-                             NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16,
-                             NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64);
-    break;
-  case 4:
-    Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
-                             NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
-                             NVPTX::StoreRetvalV4I32, {/* no v4i64 */});
-    break;
-  }
-  if (!Opcode)
-    return false;
-
-  SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
-  CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
-
-  ReplaceNode(N, Ret);
-  return true;
-}
-
 // Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
 #define getOpcV2H(ty, opKind0, opKind1)                                        \
   NVPTX::StoreParamV2##ty##_##opKind0##opKind1
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index ff58e4486a222..79db4edae6bc1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -80,7 +80,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryStore(SDNode *N);
   bool tryStoreVector(SDNode *N);
   bool tryLoadParam(SDNode *N);
-  bool tryStoreRetval(SDNode *N);
   bool tryStoreParam(SDNode *N);
   bool tryFence(SDNode *N);
   void SelectAddrSpaceCast(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d2fafe854e9e4..98ee3ab6b112c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -370,7 +370,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
       } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
         // v2i8 is promoted to v2i16
         NumElts = 1;
-        EltVT = MVT::v2i16;
+        EltVT = MVT::v2i8;
       }
       for (unsigned j = 0; j != NumElts; ++j) {
         ValueVTs.push_back(EltVT);
@@ -1085,9 +1085,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::CallSymbol)
     MAKE_CASE(NVPTXISD::Prototype)
     MAKE_CASE(NVPTXISD::MoveParam)
-    MAKE_CASE(NVPTXISD::StoreRetval)
-    MAKE_CASE(NVPTXISD::StoreRetvalV2)
-    MAKE_CASE(NVPTXISD::StoreRetvalV4)
     MAKE_CASE(NVPTXISD::PseudoUseParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
@@ -1472,7 +1469,11 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
 }
 
 static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
-  return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+  if (Flags.isSExt())
+    return ISD::SIGN_EXTEND;
+  if (Flags.isZExt())
+    return ISD::ZERO_EXTEND;
+  return ISD::ANY_EXTEND;
 }
 
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
@@ -3448,10 +3449,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       }
       InVals.push_back(P);
     } else {
-      bool aggregateIsPacked = false;
-      if (StructType *STy = dyn_cast<StructType>(Ty))
-        aggregateIsPacked = STy->isPacked();
-
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
       ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
@@ -3464,9 +3461,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
       unsigned I = 0;
       for (const unsigned NumElts : VectorInfo) {
-        const EVT EltVT = VTs[I];
         // i1 is loaded/stored as i8
-        const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+        const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
         // If the element is a packed type (ex. v2f16, v4i8, etc) holding
         // multiple elements.
         const unsigned PackingAmt =
@@ -3478,14 +3474,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
-        const MaybeAlign PartAlign = [&]() -> MaybeAlign {
-          if (aggregateIsPacked)
-            return Align(1);
-          if (NumElts != 1)
-            return std::nullopt;
-          Align PartAlign = DAG.getEVTAlign(EltVT);
-          return commonAlignment(PartAlign, Offsets[I]);
-        }();
+        const MaybeAlign PartAlign = commonAlignment(ArgAlign, Offsets[I]);
         SDValue P =
             DAG.getLoad(VecVT, dl, Root, VecAddr,
                         MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
@@ -3497,14 +3486,14 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
                                                       : ISD::EXTRACT_VECTOR_ELT,
                                     dl, LoadVT, P,
-                                    DAG.getIntPtrConstant(J * PackingAmt, dl));
+                                    DAG.getVectorIdxConstant(J * PackingAmt, dl));
 
           // Extend or truncate the element if necessary (e.g. an i8 is loaded
           // into an i16 register)
           const EVT ExpactedVT = ArgIns[I + J].VT;
-          assert((Elt.getValueType().bitsEq(ExpactedVT) ||
-                  (ExpactedVT.isScalarInteger() &&
-                   Elt.getValueType().isScalarInteger())) &&
+          assert((Elt.getValueType() == ExpactedVT ||
+                  (ExpactedVT.isInteger() &&
+                   Elt.getValueType().isInteger())) &&
                  "Non-integer argument type size mismatch");
           if (ExpactedVT.bitsGT(Elt.getValueType()))
             Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
@@ -3524,33 +3513,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   return Chain;
 }
 
-// Use byte-store when the param adress of the return value is unaligned.
-// This may happen when the return value is a field of a packed structure.
-static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
-                                      uint64_t Offset, EVT ElementType,
-                                      SDValue RetVal, const SDLoc &dl) {
-  // Bit logic only works on integer types
-  if (adjustElementType(ElementType))
-    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
-  // Store each byte
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    // Shift the byte to the last byte position
-    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
-                                   DAG.getConstant(i * 8, dl, MVT::i32));
-    SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
-                               ShiftVal};
-    // Trunc store only the last byte by using
-    //     st.param.b8
-    // The register type can be larger than b8.
-    Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
-                                    DAG.getVTList(MVT::Other), StoreOperands,
-                                    MVT::i8, MachinePointerInfo(), std::nullopt,
-                                    MachineMemOperand::MOStore);
-  }
-  return Chain;
-}
-
 SDValue
 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                  bool isVarArg,
@@ -3572,10 +3534,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
-  for (const unsigned I : llvm::seq(VTs.size()))
-    if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
-      VTs[I] = *PromotedVT;
-
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
   // they are signed or unsigned types.
@@ -3587,12 +3545,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
            "OutVal type should always be legal");
 
-    if (ExtendIntegerRetVal) {
-      RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
-    } else if (RetVal.getValueSizeInBits() < 16) {
-      // Use 16-bit registers for small load-stores as it's the
-      // smallest general purpose register size supported by NVPTX.
-      RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
+    EVT VTI = VTs[I];
+    if (const auto PromotedVT = PromoteScalarIntegerPTX(VTI))
+      VTI = *PromotedVT;
+
+    const EVT StoreVT =
+        ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+    assert((RetVal.getValueType() == StoreVT ||
+            (StoreVT.isInteger() && RetVal.getValueType().isInteger())) &&
+           "Non-integer argument type size mismatch");
+    if (StoreVT.bitsGT(RetVal.getValueType())) {
+      RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, StoreVT, RetVal);
+    } else if (StoreVT.bitsLT(RetVal.getValueType())) {
+      RetVal = DAG.getNode(ISD::TRUNCATE, dl, StoreVT, RetVal);
     }
     return RetVal;
   };
@@ -3601,45 +3567,36 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
   unsigned I = 0;
   for (const unsigned NumElts : VectorInfo) {
-    const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
-    if (NumElts == 1 && RetTy->isAggregateType() &&
-        CurrentAlign < DAG.getEVTAlign(VTs[I])) {
-      Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
-                                     GetRetVal(I), dl);
-
-      // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
-      // into the graph, so just move on to the next element.
-      I++;
-      continue;
-    }
+    const MaybeAlign CurrentAlign = ExtendIntegerRetVal
+                                        ? MaybeAlign(std::nullopt)
+                                        : commonAlignment(RetAlign, Offsets[I]);
 
-    SmallVector<SDValue, 6> StoreOperands{
-        Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};
-
-    for (const unsigned J : llvm::seq(NumElts))
-      StoreOperands.push_back(GetRetVal(I + J));
+    SDValue Val;
+    if (NumElts == 1) {
+      Val = GetRetVal(I);
+    } else {
+      SmallVector<SDValue, 6> StoreVals;
+      for (const unsigned J : llvm::seq(NumElts)) {
+        SDValue ValJ = GetRetVal(I + J);
+        if (ValJ.getValueType().isVector())
+          DAG.ExtractVectorElements(ValJ, StoreVals);
+        else
+          StoreVals.push_back(ValJ);
+      }
 
-    NVPTXISD::NodeType Op;
-    switch (NumElts) {
-    case 1:
-      Op = NVPTXISD::StoreRetval;
-      break;
-    case 2:
-      Op = NVPTXISD::StoreRetvalV2;
-      break;
-    case 4:
-      Op = NVPTXISD::StoreRetvalV4;
-      break;
-    default:
-      llvm_unreachable("Invalid vector info.");
+      EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
+                                StoreVals.size());
+      Val = DAG.getBuildVector(VT, dl, StoreVals);
     }
 
-    // Adjust type of load/store op if we've extended the scalar
-    // return value.
-    EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
-    Chain = DAG.getMemIntrinsicNode(
-        Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
-        MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
+    SDValue RetSymbol =
+        DAG.getNode(NVPTXISD::Wrapper, dl, MVT::i32,
+                    DAG.getTargetExternalSymbol("func_retval0", MVT::i32));
+    SDValue Ptr =
+        DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
+
+    Chain = DAG.getStore(Chain, dl, Val, Ptr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
 
     I += NumElts;
   }
@@ -5195,19 +5152,12 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
   case NVPTXISD::StoreParamV2:
     Opcode = NVPTXISD::StoreParamV4;
     break;
-  case NVPTXISD::StoreRetval:
-    Opcode = NVPTXISD::StoreRetvalV2;
-    break;
-  case NVPTXISD::StoreRetvalV2:
-    Opcode = NVPTXISD::StoreRetvalV4;
-    break;
   case NVPTXISD::StoreV2:
     MemVT = ST->getMemoryVT();
     Opcode = NVPTXISD::StoreV4;
     break;
   case NVPTXISD::StoreV4:
   case NVPTXISD::StoreParamV4:
-  case NVPTXISD::StoreRetvalV4:
   case NVPTXISD::StoreV8:
     // PTX doesn't support the next doubling of operands
     return SDValue();
@@ -5276,12 +5226,6 @@ static SDValue PerformStoreParamCombine(SDNode *N,
   return PerformStoreCombineHelper(N, DCI, 3, 1);
 }
 
-static SDValue PerformStoreRetvalCombine(SDNode *N,
-                                         TargetLowering::DAGCombinerInfo &DCI) {
-  // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, DCI, 2, 0);
-}
-
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
 ///
 static SDValue PerformADDCombine(SDNode *N,
@@ -5915,10 +5859,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case NVPTXISD::LoadV2:
     case NVPTXISD::LoadV4:
       return combineUnpackingMovIntoLoad(N, DCI);
-    case NVPTXISD::StoreRetval:
-    case NVPTXISD::StoreRetvalV2:
-    case NVPTXISD::StoreRetvalV4:
-      return PerformStoreRetvalCombine(N, DCI);
     case NVPTXISD::StoreParam:
     case NVPTXISD::StoreParamV2:
     case NVPTXISD::StoreParamV4:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 0a54a8fd71f32..0bd3b899d5b13 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -102,10 +102,7 @@ enum NodeType : unsigned {
   StoreParamV4,
   StoreParamS32, // to sext and store a <32bit value, not used currently
   StoreParamU32, // to zext and store a <32bit value, not used currently
-  StoreRetval,
-  StoreRetvalV2,
-  StoreRetvalV4,
-  LAST_MEMORY_OPCODE = StoreRetvalV4,
+  LAST_MEMORY_OPCODE = StoreParamU32,
 };
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5979054764647..a7c78809c5c7f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1991,9 +1991,6 @@ def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
 def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
 def SDTCallValProfile : SDTypeProfile<1, 0, []>;
 def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
-def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
-def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
 def SDTPseudoUseParamProfile : SDTypeProfile<0, 1, []>;
 def SDTProxyRegProfile : SDTypeProfile<1, 1, []>;
 
@@ -2068,15 +2065,6 @@ def CallVal :
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
 def MoveParam :
   SDNode<"NVPTXISD::MoveParam", SDTMoveParamProfile, []>;
-def StoreRetval :
-  SDNode<"NVPTXISD::StoreRetval", SDTStoreRetvalProfile,
-         [SDNPHasChain, SDNPSideEffect]>;
-def StoreRetvalV2 :
-  SDNode<"NVPTXISD::StoreRetvalV2", SDTStoreRetvalV2Profile,
-         [SDNPHasChain, SDNPSideEffect]>;
-def StoreRetvalV4 :
-  SDNode<"NVPTXISD::StoreRetvalV4", SDTStoreRetvalV4Profile,
-         [SDNPHasChain, SDNPSideEffect]>;
 def PseudoUseParam :
   SDNode<"NVPTXISD::PseudoUseParam", SDTPseudoUseParamProfile,
          [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
@@ -2153,25 +2141,6 @@ let mayStore = true in {
                           " \t[param$a$b], {{$val1, $val2, $val3, $val4}};",
                           []>;
   }
-
-  class StoreRetvalInst<NVPTXRegClass regclass, string opstr> :
-        NVPTXInst<(outs), (ins regclass:$val, Offseti32imm:$a),
-                  !strconcat("st.param", opstr, " \t[func_retval0$a], $val;"),
-                  []>;
-
-  class StoreRetvalV2Inst<NVPTXRegClass regclass, string opstr> :
-        NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, Offseti32imm:$a),
-                  !strconcat("st.param.v2", opstr,
-                             " \t[func_retval0$a], {{$val, $val2}};"),
-                  []>;
-
-  class StoreRetvalV4Inst<NVPTXRegClass regclass, string opstr> :
-        NVPTXInst<(outs),
-                  (ins regclass:$val, regclass:$val2, regclass:$val3,
-                       regclass:$val4, Offseti32imm:$a),
-                  !strconcat("st.param.v4", opstr,
-                             " \t[func_retval0$a], {{$val, $val2, $val3, $val4}};"),
-                  []>;
 }
 
 let isCall=1 in {
@@ -2230,20 +2199,6 @@ defm StoreParamV2F64  : StoreParamV2Inst<B64, f64imm, ".b64">;
 
 defm StoreParamV4F32  : StoreParamV4Inst<B32, f32imm, ".b32">;
 
-def StoreRetvalI64    : StoreRetvalInst<B64, ".b64">;
-def StoreRetvalI32    : StoreRetvalInst<B32, ".b32">;
-def StoreRetvalI16    : StoreRetvalInst<B16, ".b16">;
-def StoreRetvalI8     : StoreRetvalInst<B16, ".b8">;
-def StoreRetvalI8TruncI32 : StoreRetvalInst<B32, ".b8">;
-def StoreRetvalI8TruncI64 : StoreRetvalInst<B64, ".b8">;
-def StoreRetvalV2I64  : StoreRetvalV2Inst<B64, ".b64">;
-def StoreRetvalV2I32  : StoreRetvalV2Inst<B32, ".b32">;
-def StoreRetvalV2I16  : StoreRetvalV2Inst<B16, ".b16">;
-def StoreRetvalV2I8   : StoreRetvalV2Inst<B16, ".b8">;
-def StoreRetvalV4I32  : StoreRetvalV4Inst<B32, ".b32">;
-def StoreRetvalV4I16  : StoreRetvalV4Inst<B16, ".b16">;
-def StoreRetvalV4I8   : StoreRetvalV4Inst<B16, ".b8">;
-
 def CallArgBeginInst : NVPTXInst<(outs), (ins), "(", [(CallArgBegin)]>;
 def CallArgEndInst1  : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
 def CallArgEndInst0  : NVPTXInst<(outs), (ins), ")"...
[truncated]

@github-actions
Copy link

github-actions bot commented Jun 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8 branch from 3aec38e to 14aa0a0 Compare June 24, 2025 20:38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes how we return true values. While any non-zero value is valid, this is the change that will be observable by the users.

How hard would that be to preserve the current behavior?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how this would be observable by the users. A call to a function which returns an i1 will still be lowered such that the high bits do not matter. In practice I think -1 is a bit preferable as it could allow for other folds such as using set in this case https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-set

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with -1 as the representation of the bool true, but I also don't want it to become a showstopper for existing LLVM users. Existing CUDA code loves to bitcast things rather indiscriminately. Bitcasting between char/*int8_t/bool happens. I'm not saying that all of them are valid, but they do happen, and this change will affect them.
Perhaps it will be fine, but there's also a non-zero chance that I'll have users that will complain and we would not be able to debug and fix their stuff quickly.

We can land the patch as is. If the issues pop-up, we can revert and then rework it to make it a noop for in memory representation of bools. Or we can make it noop up-front if it's relatively easy, and address this issue separately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change only impacts how i1 return values are lowered, not the in-memory representation of these types in general. As far as I can tell, this would only be observable if someone were to say bitcast a function pointer to change the return type. Even then, it wouldn't be visible for a clang-cuda program as the front-end adds the zeroext attribute to parameters and return values, and this will force us to continue to use 1.

https://godbolt.org/z/x6odW4den

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is due to the setting we have in NVPTXISelLowering():

setBooleanContents(ZeroOrNegativeOneBooleanContent);
setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);

However, in the ISA docs, the description of semantics for instructions like set suggest we should use ZeroOrOneBooleanContent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a result of this:

// anyext i1
def : Pat<(i16 (anyext i1:$a)), (SELP_b16ii -1, 0, $a)>;
def : Pat<(i32 (anyext i1:$a)), (SELP_b32ii -1, 0, $a)>;
def : Pat<(i64 (anyext i1:$a)), (SELP_b64ii -1, 0, $a)>;

Comment on lines 40 to 43
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably that's where the #145552 will help to use the constants directly, without intermediate register moves.

Comment on lines +78 to 80
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yay! We're finally loading v2i8 correctly.
Nit: the and is not needed here, as ld.param.v2.b8 already zero-extended the values.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the code gen here is still quite poor. I'll look into this separately.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8 branch from 14aa0a0 to dce994e Compare June 25, 2025 15:40
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8 branch from cca3b27 to fefbb39 Compare June 25, 2025 19:37
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8 branch from fefbb39 to ca2f50c Compare June 26, 2025 02:19
Copy link
Contributor

@Prince781 Prince781 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with some questions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is due to the setting we have in NVPTXISelLowering():

setBooleanContents(ZeroOrNegativeOneBooleanContent);
setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);

However, in the ISA docs, the description of semantics for instructions like set suggest we should use ZeroOrOneBooleanContent.

@AlexMaclean AlexMaclean merged commit f03782d into llvm:main Jun 27, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants