
Conversation

@dakersnar
Contributor

@dakersnar dakersnar commented Sep 17, 2025

This backend support will allow the LoadStoreVectorizer, in certain cases, to fill in gaps when creating load/store vectors and generate LLVM masked loads/stores (https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics). To accomplish this, the changes are split into two parts. This first part contains the backend lowering and TTI changes, and a follow-up PR will have the LSV generate these intrinsics: #159388.

In this backend change, masked loads are lowered to PTX with #pragma "used_bytes_mask" [mask]; (https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask), and masked stores are lowered to PTX using the new sink symbol syntax (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st).
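
To make the two lowerings concrete, here is a minimal IR-level sketch; the function, pointer names, and mask values are illustrative and not taken from the patch's tests, and the comments only describe the rough PTX shape implied by the lowering described above:

; Hypothetical 256-bit masked store and load with constant masks, the only
; masked form this change supports on NVPTX. Lanes 1, 2, 5, and 6 are masked off.
declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32 immarg, <8 x i1>)
declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), i32 immarg, <8 x i1>, <8 x i32>)

define void @example(ptr addrspace(1) %in, ptr addrspace(1) %out, <8 x i32> %v) {
  ; Lowers to a st.global.v8.b32 where each masked-off position holds the
  ; sink symbol "_" instead of a register.
  call void @llvm.masked.store.v8i32.p1(
      <8 x i32> %v, ptr addrspace(1) %out, i32 32,
      <8 x i1> <i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true>)
  ; Lowers to a normal ld.global.v8.b32 preceded by a
  ; #pragma "used_bytes_mask" covering the bytes of the enabled lanes.
  %x = call <8 x i32> @llvm.masked.load.v8i32.p1(
      ptr addrspace(1) %in, i32 32,
      <8 x i1> <i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true>,
      <8 x i32> poison)
  ret void
}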

TTI Changes

TTI changes are needed because NVPTX only supports masked loads/stores with constant masks. ScalarizeMaskedMemIntrin.cpp is adjusted to check whether the mask is constant and pass that result into the TTI check. Behavior shouldn't change for non-NVPTX targets, which do not care whether the mask is variable or constant when determining legality, but every TTI implementation of this API needs to be updated.
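
As a sketch of what the new IsMaskConstant parameter changes (hypothetical IR, not taken from the added tests): a constant-mask store is now reported legal by the NVPTX TTI and survives to ISel, while a variable-mask store is still reported illegal and scalarized as before.

declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32 immarg, <8 x i1>)

define void @constant_mask(ptr addrspace(1) %p, <8 x i32> %v) {
  ; isLegalMaskedStore(..., AddrSpace=1, /*IsMaskConstant=*/true) returns true
  ; for this 256-bit, 32-byte-aligned global store, so the intrinsic is kept.
  call void @llvm.masked.store.v8i32.p1(
      <8 x i32> %v, ptr addrspace(1) %p, i32 32,
      <8 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false>)
  ret void
}

define void @variable_mask(ptr addrspace(1) %p, <8 x i32> %v, <8 x i1> %m) {
  ; IsMaskConstant is false here, so NVPTX reports the store as illegal and
  ; ScalarizeMaskedMemIntrin expands it into per-element conditional stores.
  call void @llvm.masked.store.v8i32.p1(<8 x i32> %v, ptr addrspace(1) %p, i32 32, <8 x i1> %m)
  ret void
}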

Masked store lowering implementation details

If the masked stores make it to the NVPTX backend without being scalarized, they are handled by the following:

  • NVPTXISelLowering.cpp - Sets up a custom operation action and handles it in lowerMSTORE. The handling is similar to normal store vectors, except we read the mask and place the sentinel register $noreg in each position where the mask reads as false.

For example,

t10: v8i1 = BUILD_VECTOR Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>, Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>
t11: ch = masked_store<(store unknown-size into %ir.lsr.iv28, align 32, addrspace 1)> t5:1, t5, t7, undef:i64, t10

->

STV_i32_v8 killed %13:int32regs, $noreg, $noreg, killed %16:int32regs, killed %17:int32regs, $noreg, $noreg, killed %20:int32regs, 0, 0, 1, 8, 0, 32, %4:int64regs, 0, debug-location !18 :: (store unknown-size into %ir.lsr.iv28, align 32, addrspace 1);

  • NVPTXInstrInfo.td - Changes the definition of store vectors to allow a mix of sink symbols and registers.
  • NVPTXInstPrinter.h/.cpp - Handles the $noreg case by printing "_".

Masked load lowering implementation details

Masked loads are routed to normal PTX loads, with one difference: a #pragma "used_bytes_mask" is emitted before the load instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask). To accomplish this, a new operand representing this mask is added to every NVPTXISD load type.

  • NVPTXISelLowering.h/.cpp - Masked loads are converted into normal NVPTXISD loads with a mask operand in two ways: 1) in type legalization through replaceLoadVector, which is the normal path, and 2) through LowerMLOAD, to handle the legal vector types (v2f16/v2bf16/v2i16/v4i8/v2f32) that will not be type legalized (see the sketch after this list). Both paths share the same convertMLOADToLoadWithUsedBytesMask helper, and both default this operand to UINT32_MAX, representing all bytes enabled. For the latter, we need a new NVPTXISD::MLoadV1 node type to represent that edge case, because we cannot put the used bytes mask operand on a generic LoadSDNode.
  • NVPTXISelDAGToDAG.cpp - Extract used bytes mask from loads, add them to created machine instructions.
  • NVPTXInstPrinter.h/.cpp - Print the pragma when the used bytes mask isn't all ones.
  • NVPTXForwardParams.cpp, NVPTXReplaceImageHandles.cpp - Update manual indexing of load operands to account for new operand.
  • NVPTXInstrInfo.td, NVPTXIntrinsics.td - Add the used bytes mask to the MI definitions.
  • NVPTXTagInvariantLoads.cpp - Ensure that masked loads also get tagged as invariant.
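
A rough sketch of the two load paths named above (the vector types and mask values are illustrative only; whether a given type actually reaches the backend depends on the TTI legality checks):

declare <8 x float> @llvm.masked.load.v8f32.p1(ptr addrspace(1), i32 immarg, <8 x i1>, <8 x float>)
declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), i32 immarg, <4 x i1>, <4 x i8>)

define void @load_paths(ptr addrspace(1) %a, ptr addrspace(1) %b) {
  ; v8f32 is not a legal NVPTX type, so this takes the type-legalization path
  ; (replaceLoadVector) and becomes an NVPTXISD load carrying a used-bytes-mask operand.
  %wide = call <8 x float> @llvm.masked.load.v8f32.p1(
      ptr addrspace(1) %a, i32 32,
      <8 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false>,
      <8 x float> poison)
  ; v4i8 is already legal, so it is handled by LowerMLOAD and lowered through
  ; the new NVPTXISD::MLoadV1 node instead.
  %narrow = call <4 x i8> @llvm.masked.load.v4i8.p1(
      ptr addrspace(1) %b, i32 4,
      <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x i8> poison)
  ret void
}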

Some generic changes that are needed:

  • LegalizeVectorTypes.cpp - Ensure flags are preserved when splitting masked loads.
  • SelectionDAGBuilder.cpp - Preserve MD_invariant_load on masked load SDNode creation.

@llvmbot
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-hexagon
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-backend-risc-v
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-backend-arm

Author: Drew Kersnar (dakersnar)

Changes

This backend support will allow the LoadStoreVectorizer, in certain cases, to fill in gaps when creating store vectors and generate LLVM masked stores (https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics), which get lowered to PTX using the new sink symbol syntax (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st). To accomplish this, changes are separated into two parts. This first part has the backend lowering and TTI changes, and a follow up PR will have the LSV generate these intrinsics: [INSERT]

TTI changes are needed because NVPTX only supports masked stores with constant masks. ScalarizeMaskedMemIntrin.cpp is adjusted to check that the mask is constant and pass that result into the TTI check. Behavior shouldn't change for non-NVPTX targets, which do not care whether the mask is variable or constant when determining legality.

If the masked stores make it to the NVPTX backend without being scalarized, they are handled by the following:

  • NVPTXISelLowering.cpp - Sets up a custom operation action and handles it in lowerMSTORE. Similar handling to normal store vectors, except we read the mask and place a sentinel register $noreg in each position where the mask reads as false.

For example,

t10: v8i1 = BUILD_VECTOR Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>, Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>
t11: ch = masked_store<(store unknown-size into %ir.lsr.iv28, align 32, addrspace 1)> t5:1, t5, t7, undef:i64, t10

->

STV_i32_v8 killed %13:int32regs, $noreg, $noreg, killed %16:int32regs, killed %17:int32regs, $noreg, $noreg, killed %20:int32regs, 0, 0, 1, 8, 0, 32, %4:int64regs, 0, debug-location !18 :: (store unknown-size into %ir.lsr.iv28, align 32, addrspace 1);

  • NVPTXInstrInfo.td - changes the definition of store vectors to allow for a mix of sink symbols and registers.
  • NVPTXInstPrinter.h/.cpp - Handles the $noreg case by printing "_".

Patch is 35.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159387.diff

20 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+6-2)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1-1)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+2-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp (+1-1)
  • (modified) llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+10)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+88-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+7-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp (+26)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+3)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Target/VE/VETargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Target/X86/X86TargetTransformInfo.cpp (+1-1)
  • (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp (+2-1)
  • (added) llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll (+56)
  • (added) llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll (+319)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 41ff54f0781a2..e7886537379bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,9 +810,13 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
-  /// Return true if the target supports masked store.
+  /// Return true if the target supports masked store. A value of false for
+  /// IsMaskConstant indicates that the mask could either be variable or
+  /// constant. This is for targets that only support masked store with a
+  /// constant mask.
   LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace) const;
+                                   unsigned AddressSpace,
+                                   bool IsMaskConstant = false) const;
   /// Return true if the target supports masked load.
   LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                                   unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 566e1cf51631a..33705e1dd5f98 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const {
+                                  unsigned AddressSpace, bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 09b50c5270e57..838712e55d0dd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+                                             unsigned AddressSpace, bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                          unsigned AddressSpace) const override;
 
   bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace) const override {
+                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/) const {
+                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..dc6b631d33451 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
   }
 }
 
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O,
+                                                 const char *Modifier) {
+  const MCOperand &Op = MI->getOperand(OpNum);
+  if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+    O << "_";
+  else
+    printOperand(MI, OpNum, O);
+}
+
 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
                                       raw_ostream &O) {
   int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     StringRef Modifier = {});
   void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                        StringRef Modifier = {});
+  void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+                                 const char *Modifier = nullptr);
   void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..6810b6008d8cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
   return Or;
 }
 
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+  SDNode *N = Op.getNode();
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Val = N->getOperand(1);
+  SDValue BasePtr = N->getOperand(2);
+  SDValue Offset = N->getOperand(3);
+  SDValue Mask = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ValVT = Val.getValueType();
+  MemSDNode *MemSD = cast<MemSDNode>(N);
+  assert(ValVT.isVector() && "Masked vector store must have vector type");
+  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+         "Unexpected alignment for masked store");
+
+  unsigned Opcode = 0;
+  switch (ValVT.getSimpleVT().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected masked vector store type");
+  case MVT::v4i64:
+  case MVT::v4f64: {
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  }
+  case MVT::v8i32:
+  case MVT::v8f32: {
+    Opcode = NVPTXISD::StoreV8;
+    break;
+  }
+  }
+
+  SmallVector<SDValue, 8> Ops;
+
+  // Construct the new SDNode. First operand is the chain.
+  Ops.push_back(Chain);
+
+  // The next N operands are the values to store. Encode the mask into the
+  // values using the sentinel register 0 to represent a masked-off element.
+  assert(Mask.getValueType().isVector() &&
+         Mask.getValueType().getVectorElementType() == MVT::i1 &&
+         "Mask must be a vector of i1");
+  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+         "Mask expected to be a BUILD_VECTOR");
+  assert(Mask.getValueType().getVectorNumElements() ==
+             ValVT.getVectorNumElements() &&
+         "Mask size must be the same as the vector size");
+  for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    if (Mask->getConstantOperandVal(I) == 0) {
+      // Append a sentinel register 0 to the Ops vector to represent a masked
+      // off element, this will be handled in tablegen
+      Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+                                    ValVT.getVectorElementType()));
+    } else {
+      // Extract the element from the vector to store
+      SDValue ExtVal =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+                      Val, DAG.getIntPtrConstant(I, DL));
+      Ops.push_back(ExtVal);
+    }
+  }
+
+  // Next, the pointer operand.
+  Ops.push_back(BasePtr);
+
+  // Finally, the offset operand. We expect this to always be undef, and it will
+  // be ignored in lowering, but to mirror the handling of the other vector
+  // store instructions we include it in the new SDNode.
+  assert(Offset.getOpcode() == ISD::UNDEF &&
+         "Offset operand expected to be undef");
+  Ops.push_back(Offset);
+
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  return NewSt;
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
   case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..a8d6ff60c9b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1500,6 +1500,10 @@ def ADDR : Operand<pAny> {
   let MIOperandInfo = (ops ADDR_base, i32imm);
 }
 
+def RegOrSink : Operand<Any> {
+  let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
 def AtomicCode : Operand<i32> {
   let PrintMethod = "printAtomicCode";
 }
@@ -1806,7 +1810,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
     "\t[$addr], {{$src1, $src2}};">;
   def _v4 : NVPTXInst<
     (outs),
-    (ins O:$src1, O:$src2, O:$src3, O:$src4,
+    (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
          AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
          ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1814,8 +1818,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs),
-      (ins O:$src1, O:$src2, O:$src3, O:$src4,
-           O:$src5, O:$src6, O:$src7, O:$src8,
+      (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+           RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
            AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
            ADDR:$addr),
       "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f4f89613b358d..88b13cb38d67b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -597,6 +597,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+                                      unsigned AddrSpace, bool IsMaskConstant) const {
+
+  if (!IsMaskConstant)
+    return false;
+  
+  //  We currently only support this feature for 256-bit vectors, so the
+  //  alignment must be at least 32
+  if (Alignment < 32)
+    return false;
+
+  if (!ST->has256BitVectorLoadStore(AddrSpace))
+    return false;
+
+  auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+  if (!VTy)
+    return false;
+
+  auto *ScalarTy = VTy->getScalarType();
+  if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+      (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+    return true;
+
+  return false;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b32d931bd3074..9e5500966fe10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddrSpace, bool IsMaskConstant) const override;
+
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 47e0a250d285a..80f10eb29bca4 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace) const {
+                                    unsigned AddressSpace, bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 42d6680c3cb7d..412c1b04cdf3a 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1137,7 +1137,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(3))))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<9>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT:    ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT:    and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT:    ld.param.b8 %rs5, [global_var...
[truncated]

@llvmbot
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-backend-nvptx

@github-actions

github-actions bot commented Sep 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@dakersnar dakersnar changed the title [NVPTX] Lower LLVM masked vector stores to PTX using new sink symbol syntax [NVPTX] Lower LLVM masked vector loads and stores to PTX Oct 17, 2025
@llvmbot llvmbot added the llvm:SelectionDAG label Oct 17, 2025
@dakersnar
Contributor Author

I've added masked load lowering support and updated the description to explain the changes in each file in detail. Hope this makes reviewing easier, but let me know if anyone has any questions.

@dakersnar dakersnar requested a review from Prince781 October 17, 2025 23:55
@dakersnar dakersnar force-pushed the github/dkersnar/masked-store-lowering branch from dc244cb to abbf0ba Compare October 21, 2025 18:50
Member

@Artem-B Artem-B left a comment


LGTM overall with a few nits.

@dakersnar dakersnar force-pushed the github/dkersnar/masked-store-lowering branch from 5f54ae6 to 0ad2f80 Compare October 27, 2025 21:51
@Prince781 Prince781 self-requested a review November 1, 2025 19:00
Contributor

@Prince781 Prince781 left a comment


LGTM. Do you think we should have tests using vectors with non-byte-aligned types (eg. <4 x i5>)?

@dakersnar
Contributor Author

LGTM. Do you think we should have tests using vectors with non-byte-aligned types (eg. <4 x i5>)?

Now that I think about it, I think it would make more sense to add an additional filter to the masked load legality check to avoid sub-byte element types. Our used bytes mask pragma is not able to support them, and I don't think we will benefit from their creation in the vectorizer.

@dakersnar
Contributor Author

@Artem-B Let me know if you have any other comments, or thoughts on my responses to your feedback so far. Also, if you get the chance to look at the corresponding LSV change again, that would be great too: #159388. Hoping for 1 approval on each from a non-NVIDIA reviewer.

Member

@Artem-B Artem-B left a comment


LGTM.

@dakersnar dakersnar force-pushed the github/dkersnar/masked-store-lowering branch from c9e0b9f to 8013458 Compare November 21, 2025 23:01
@github-actions

github-actions bot commented Nov 21, 2025

🐧 Linux x64 Test Results

  • 186553 tests passed
  • 4882 tests skipped

@dakersnar dakersnar merged commit 17852de into llvm:main Nov 25, 2025
10 checks passed
@dakersnar
Contributor Author

Missed something on windows, there are test failures here: https://lab.llvm.org/buildbot/#/builders/155/builds/15135/steps/7/logs/stdio

I need to add an mattr argument for the sm_90 version of the test commands. I can put up a quick fix now.

@dakersnar
Contributor Author

Fix up here: #169535

@llvm-ci
Collaborator

llvm-ci commented Nov 25, 2025

LLVM Buildbot has detected a new failure on builder llvm-nvptx64-nvidia-ubuntu running on as-builder-7 while building llvm at step 6 "test-build-unified-tree-check-llvm".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/160/builds/28851

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-llvm) failure: test (failure)
******************** TEST 'LLVM :: CodeGen/NVPTX/masked-store-variable-mask.ll' FAILED ********************
Exit Code: 255

Command Output (stdout):
--
# RUN: at line 2
/home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/bin/llc < /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/bin/FileCheck /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -check-prefixes=CHECK
# executed command: /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/bin/llc -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88
# executed command: /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/bin/FileCheck /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -check-prefixes=CHECK
# RUN: at line 3
/home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/bin/llc < /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | /usr/local/cuda/bin/ptxas -c - -arch=sm_100
# executed command: /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/bin/llc -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88
# executed command: /usr/local/cuda/bin/ptxas -c - -arch=sm_100
# .---command stderr------------
# | ptxas /home/buildbot/worker/as-builder-7/llvm-nvptx64-nvidia-ubuntu/build/lit-tmp-be3tr2oa/tmpxft_002f2b98_00000000-0_stdin, line 5; fatal   : Unsupported .version 8.8; current version is '8.7'
# | ptxas fatal   : Ptx assembly aborted due to errors
# `-----------------------------
# error: command failed with exit status: 255

--

********************


@llvm-ci
Collaborator

llvm-ci commented Nov 25, 2025

LLVM Buildbot has detected a new failure on builder llvm-nvptx-nvidia-ubuntu running on as-builder-7 while building llvm at step 6 "test-build-unified-tree-check-llvm".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/180/builds/28991

Here is the relevant piece of the build log for reference:
Step 6 (test-build-unified-tree-check-llvm) failure: test (failure)
******************** TEST 'LLVM :: CodeGen/NVPTX/masked-store-variable-mask.ll' FAILED ********************
Exit Code: 255

Command Output (stdout):
--
# RUN: at line 2
/home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/bin/llc < /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/bin/FileCheck /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -check-prefixes=CHECK
# executed command: /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/bin/llc -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88
# executed command: /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/bin/FileCheck /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -check-prefixes=CHECK
# RUN: at line 3
/home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/bin/llc < /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/llvm-project/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | /usr/local/cuda/bin/ptxas -c - -arch=sm_100
# executed command: /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/bin/llc -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88
# executed command: /usr/local/cuda/bin/ptxas -c - -arch=sm_100
# .---command stderr------------
# | ptxas /home/buildbot/worker/as-builder-7/llvm-nvptx-nvidia-ubuntu/build/lit-tmp-cpgrz9kk/tmpxft_0031a72b_00000000-0_stdin, line 5; fatal   : Unsupported .version 8.8; current version is '8.7'
# | ptxas fatal   : Ptx assembly aborted due to errors
# `-----------------------------
# error: command failed with exit status: 255

--

********************


augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Dec 3, 2025
kcloudy0717 pushed a commit to kcloudy0717/llvm-project that referenced this pull request Dec 4, 2025