Skip to content

Conversation

@grypp
Copy link
Member

@grypp grypp commented Aug 19, 2025

Key Features

  1. Multiple SSA returns – no struct packing/unpacking required.
  2. Automatic struct unpacking – values are directly usable.
  3. Readable register mapping
    • {$rwN} → read-write
    • {$rN} → read-only
    • {$wN} → write-only
  4. Full read-write support (+ modifier).
  5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints.
  6. Predicate support: PTX @p predication support

IR Example:

%wo0, %wo1 = nvvm.inline_ptx """
 .reg .pred p;
 setp.ge.s32 p,   {$r0}, {$r1};
 selp.s32 {$rw0}, {$r0}, {$r1}, p;
 selp.s32 {$rw1}, {$r0}, {$r1}, p;
 selp.s32 {$w0},  {$r0}, {$r1}, p;
 selp.s32 {$w1},  {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32

After lowering

 %0 = llvm.inline_asm has_side_effects asm_dialect = att
 "{
                              .reg .pred p;\
                              setp.ge.s32 p, $4, $5;   \
                              selp.s32   $0, $4, $5, p;\
                              selp.s32   $1, $4, $5, p;\
                              selp.s32   $2, $4, $5, p;\
                              selp.s32   $3, $4, $5, p;\
   }"
   "=r,=r,=f,=f,f,f,0,1"
   %c500_i32, %c400_i32, %cst, %cst_0
   : (i32, i32, f32, f32)
   -> !llvm.struct<(i32, i32, f32, f32)>

 %1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
 %2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
 %3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
 %4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>

 // Unpacked result from nvvm.inline_ptx
 %5 = arith.addi %1, %2 : i32
 // read only
 %6 = arith.addf %cst, %cst_0 : f32
 // write only
 %7 = arith.addf %3, %4 : f32

@llvmbot
Copy link
Member

llvmbot commented Aug 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

Changes

Key Features

  1. Multiple SSA returns – no struct packing/unpacking required.
  2. Automatic struct unpacking – values are directly usable.
  3. Readable register mapping
    • {$rwN} → read-write
    • {$rN} → read-only
    • {$wN} → write-only
  4. Full read-write support (+ modifier).
  5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints.
  6. Predicate support: PTX @p predication support

IR Example:

%wo0, %wo1 = nvvm.inline_ptx """
 .reg .pred p;
 setp.ge.s32 p,   {$r0}, {$r1};
 selp.s32 {$rw0}, {$r0}, {$r1}, p;
 selp.s32 {$rw1}, {$r0}, {$r1}, p;
 selp.s32 {$w0},  {$r0}, {$r1}, p;
 selp.s32 {$w1},  {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32

After lowering

 %0 = llvm.inline_asm has_side_effects asm_dialect = att
 "{
                              .reg .pred p;\
                              setp.ge.s32 p, $4, $5;   \
                              selp.s32   $0, $4, $5, p;\
                              selp.s32   $1, $4, $5, p;\
                              selp.s32   $2, $4, $5, p;\
                              selp.s32   $3, $4, $5, p;\
   }"
   "=r,=r,=f,=f,f,f,0,1"
   %c500_i32, %c400_i32, %cst, %cst_0
   : (i32, i32, f32, f32)
   -> !llvm.struct<(i32, i32, f32, f32)>

 %1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
 %2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
 %3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
 %4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>

 // Unpacked result from nvvm.inline_ptx
 %5 = arith.addi %1, %2 : i32
 // read only
 %6 = arith.addf %cst, %cst_0 : f32
 // write only
 %7 = arith.addf %3, %4 : f32

Patch is 29.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154358.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h (+9-5)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td (+8-5)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+14-8)
  • (modified) mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp (+289-32)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+18-1)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+54-6)
  • (modified) mlir/test/python/dialects/nvvm.py (+41)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 3e3fcd7d1fb82..99b1d9709e3e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -26,9 +26,9 @@ namespace NVVM {
 enum class PTXRegisterMod {
   /// Read register with no modifier
   Read = 0,
-  /// Read register with '+' modifier
+  /// Write register with '=' modifier
   Write = 2,
-  /// Read register with '=' modifier.
+  /// Read-write register with '+' modifier.
   /// Note that, this is not natively supported by LLVM, but it is possible to
   /// set read and write for the same operand.
   ReadWrite = 1,
@@ -67,13 +67,17 @@ class PtxBuilder {
   SmallVector<Value> ptxOperands;
   // Register constraints (read, write, readwrite) and register data types
   std::string registerConstraints;
-
+  // Modifiers
+  SmallVector<PTXRegisterMod> registerModifiers;
   bool hasResult = false;
+  bool needsManualMapping = false;
 
 public:
   /// Single constructor that only initializes members.
-  PtxBuilder(Operation *op, PatternRewriter &rewriter)
-      : interfaceOp(op), rewriter(rewriter) {}
+  PtxBuilder(Operation *op, PatternRewriter &rewriter,
+             bool needsManualMapping = false)
+      : interfaceOp(op), rewriter(rewriter),
+        needsManualMapping(needsManualMapping) {}
 
   /// Add an operand with the read/write input type.
   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index e98b94b5b3052..8e36749cdb361 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
             following this order:
              1) Adds results 
              2) Adds operands 
-             3) Adds attributes             
+             3) Adds attributes
+             Returns true if it does the mapping manually
           }],
-         /*retType=*/"void",
+         /*retType=*/"bool",
          /*methodName=*/"getAsmValues",
          /*args=*/(ins "::mlir::RewriterBase &":$rewriter, 
-         "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+         "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues         
+         ),
          /*methodBody=*/"",
          /*defaultImpl=*/ [{         
            mlir::Operation* op = $_op;
            
            // Step 1. Add results
-           for (auto val : op->getResults()) 
-            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+            for (auto val : op->getResults()) 
+              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
 
            // Step 2. Add operands
            for (auto val : op->getOperands()) 
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
              }
            }
+           return false; // No manual mapping needed
          }]
        >
   ];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f9cd58de8915f..786d42cf15666 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
   }];
 
   let arguments = (ins Variadic<AnyType>:$readOnlyArgs, 
+                       Variadic<AnyType>:$readWriteArgs, 
                        StrAttr:$ptxCode,
                        PtxPredicate:$predicate);
                       
   let results = (outs Variadic<AnyType>:$writeOnlyArgs);
-
-  let assemblyFormat = [{ 
-    $ptxCode `(` $readOnlyArgs `)` 
-    (`,` `predicate` `=` $predicate^)? attr-dict 
-    `:` type(operands) 
-    (`->` type($writeOnlyArgs)^)?
+  
+  let assemblyFormat = [{
+    $ptxCode
+    ( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
+    ( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict
+    ( `->` type($writeOnlyArgs)^ )?
   }];
   
   let extraClassDefinition = [{
@@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
       return std::string(ptxInstStr.data());
     }
   }];
+
+  let extraClassDeclaration = [{
+    bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    void getAsmValues(RewriterBase &rewriter, 
-        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+    bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
   }];
 }
 
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index e0144bff4d371..c67ec3642f121 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -57,9 +57,9 @@ struct PtxLowering
 
     SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
     LDBG() << op.getPtx();
-    PtxBuilder generator(op, rewriter);
 
-    op.getAsmValues(rewriter, asmValues);
+    bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
+    PtxBuilder generator(op, rewriter, needsManualMapping);
     for (auto &[asmValue, modifier] : asmValues) {
       LDBG() << asmValue << "\t Modifier : " << modifier;
       generator.insertValue(asmValue, modifier);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index e004d5f64733e..3cad9d3bd16e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,6 +12,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Regex.h"
 
 #define DEBUG_TYPE "ptx-builder"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -59,12 +62,28 @@ static char getRegisterType(Value v) {
   return getRegisterType(v.getType());
 }
 
+/// Extract every element of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+                                                Location loc, Value agg) {
+  auto structTy = cast<LLVM::LLVMStructType>(agg.getType());
+  SmallVector<Value> elems;
+  elems.reserve(structTy.getBody().size());
+  for (auto [i, t] : llvm::enumerate(structTy.getBody())) {
+    (void)t;
+    Value e = LLVM::ExtractValueOp::create(rewriter, loc, agg, i);
+    elems.push_back(e);
+  }
+  return elems;
+}
+
 void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
-  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+  registerModifiers.push_back(itype);
+
   auto getModifier = [&]() -> const char * {
     if (itype == PTXRegisterMod::ReadWrite) {
-      assert(false && "Read-Write modifier is not supported. Try setting the "
-                      "same value as Write and Read separately.");
+      // The read-write ('+') modifier is not natively supported by LLVM;
+      // the constraint string is canonicalized later, when the asm is built.
       return "+";
     }
     if (itype == PTXRegisterMod::Write) {
@@ -72,6 +91,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
     }
     return "";
   };
+
   auto addValue = [&](Value v) {
     if (itype == PTXRegisterMod::Read) {
       ptxOperands.push_back(v);
@@ -108,38 +128,222 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
 }
 
 /// Check if the operation needs to pack and unpack results.
-static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
-  return interfaceOp->getNumResults() > 1;
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+                SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+  if (needsManualMapping)
+    return false;
+  const unsigned writeOnly = interfaceOp->getNumResults();
+  const unsigned readWrite =
+      llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+        return m == PTXRegisterMod::ReadWrite;
+      });
+  return (writeOnly + readWrite) > 1;
 }
 
 /// Pack the result types of the interface operation.
 /// If the operation has multiple results, it packs them into a struct
 /// type. Otherwise, it returns the original result types.
-static SmallVector<Type> packResultTypes(MLIRContext *ctx,
-                                         BasicPtxBuilderInterface interfaceOp) {
-  TypeRange results = interfaceOp->getResultTypes();
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+                SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+                SmallVectorImpl<Value> &ptxOperands) {
+  MLIRContext *ctx = interfaceOp->getContext();
+  TypeRange resultRange = interfaceOp->getResultTypes();
+
+  if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+    // Single value path:
+    if (interfaceOp->getResults().size() == 1)
+      return SmallVector<Type>{resultRange.front()};
+
+    // No declared results: if there is an RW, forward its type.
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+      if (m == PTXRegisterMod::ReadWrite)
+        return SmallVector<Type>{v.getType()};
+  }
+
+  SmallVector<Type> packed;
+  for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+    if (m == PTXRegisterMod::ReadWrite)
+      packed.push_back(v.getType());
+  for (Type t : resultRange)
+    packed.push_back(t);
+
+  if (packed.empty())
+    return {};
+
+  auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+  return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+///  - Turn every "+X" into "=X"
+///  - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+///  "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+///  "+f,+f,+r,=r,=r"     -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+  SmallVector<llvm::StringRef> toks;
+  SmallVector<std::string> out;
+  SmallVector<unsigned> plusIdx;
+
+  csv.split(toks, ',');
+  out.reserve(toks.size() + 8);
+
+  for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+    StringRef t = toks[i].trim();
+    if (t.consume_front("+")) {
+      plusIdx.push_back(i);
+      out.push_back(("=" + t).str());
+    } else {
+      out.push_back(t.str());
+    }
+  }
+
+  // Append indices of original "+X" tokens.
+  for (unsigned idx : plusIdx)
+    out.push_back(std::to_string(idx));
+
+  // Join back to CSV.
+  std::string result;
+  result.reserve(csv.size() + plusIdx.size() * 2);
+  llvm::raw_string_ostream os(result);
+  for (size_t i = 0; i < out.size(); ++i) {
+    if (i)
+      os << ',';
+    os << out[i];
+  }
+  return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWrite{"rw"};
+constexpr llvm::StringLiteral kWriteOnly{"w"};
+constexpr llvm::StringLiteral kReadOnly{"r"};
+
+/// Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
+/// to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
+/// then `r*`. Duplicates are de-duplicated when assigning numbers.
+/// Unknown text is preserved verbatim.
+///
+/// Example Input:
+/// "{
+///       .reg .pred p;
+///       setp.ge.s32 p,   {$r0}, {$r1};
+///       selp.s32 {$rw0}, {$r0}, {$r1}, p;
+///       selp.s32 {$rw1}, {$r0}, {$r1}, p;
+///       selp.s32 {$w0},  {$r0}, {$r1}, p;
+///       selp.s32 {$w1},  {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+///       .reg .pred p;
+///       setp.ge.s32 p, $4, $5;
+///       selp.s32 $0,   $4, $5, p;
+///       selp.s32 $1,   $4, $5, p;
+///       selp.s32 $2,   $4, $5, p;
+///       selp.s32 $3,   $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
+  // Match {$rwN}, {$wN}, {$rN}
+  llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", kReadWrite,
+                               kWriteOnly, kReadOnly)
+                     .str());
+
+  llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+  llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+  {
+    StringRef rest = asmText;
+    SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+    while (!rest.empty() && rx.match(rest, &m)) {
+      unsigned num = 0;
+      (void)m[2].getAsInteger(10, num);
+
+      if (m[1].equals_insensitive(kReadWrite)) {
+        if (seenRW.insert(num).second)
+          rwNums.push_back(num);
+      } else if (m[1].equals_insensitive(kWriteOnly)) {
+        if (seenW.insert(num).second)
+          wNums.push_back(num);
+      } else {
+        if (seenR.insert(num).second)
+          rNums.push_back(num);
+      }
+
+      const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+      rest = rest.drop_front(advance);
+    }
+  }
+
+  llvm::sort(rwNums);
+  llvm::sort(wNums);
+  llvm::sort(rNums);
+
+  llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+  unsigned nextId = 0;
+  for (unsigned n : rwNums)
+    rwMap[n] = nextId++;
+  for (unsigned n : wNums)
+    wMap[n] = nextId++;
+  for (unsigned n : rNums)
+    rMap[n] = nextId++;
+
+  std::string out;
+  out.reserve(asmText.size());
 
-  if (!needsPackUnpack(interfaceOp))
-    return llvm::to_vector<1>(results);
+  size_t prev = 0;
+  StringRef rest = asmText;
+  SmallVector<StringRef, 3> m;
+  while (!rest.empty() && rx.match(rest, &m)) {
+    // Compute absolute match bounds in the original buffer.
+    size_t absStart = (size_t)(m[0].data() - asmText.data());
+    size_t absEnd = absStart + m[0].size();
 
-  SmallVector<mlir::Type> elems(results.begin(), results.end());
-  auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
-  return {sTy};
+    // Emit text before the match.
+    out.append(asmText.data() + prev, asmText.data() + absStart);
+
+    // Emit compact $K
+    unsigned num = 0;
+    (void)m[2].getAsInteger(10, num);
+    unsigned id = 0;
+    if (m[1].equals_insensitive(kReadWrite))
+      id = rwMap.lookup(num);
+    else if (m[1].equals_insensitive(kWriteOnly))
+      id = wMap.lookup(num);
+    else
+      id = rMap.lookup(num);
+
+    out.push_back('$');
+    out += std::to_string(id);
+
+    prev = absEnd;
+
+    // Advance search window.
+    const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+    rest = rest.drop_front(advance);
+  }
+
+  // Tail.
+  out.append(asmText.data() + prev, asmText.data() + asmText.size());
+  return out;
 }
 
 LLVM::InlineAsmOp PtxBuilder::build() {
-  MLIRContext *ctx = interfaceOp->getContext();
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
 
-  SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
+  SmallVector<Type> resultTypes = packResultTypes(
+      interfaceOp, needsManualMapping, registerModifiers, ptxOperands);
 
   // Remove the last comma from the constraints string.
   if (!registerConstraints.empty() &&
       registerConstraints[registerConstraints.size() - 1] == ',')
     registerConstraints.pop_back();
+  registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
 
   std::string ptxInstruction = interfaceOp.getPtx();
+  if (!needsManualMapping)
+    ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
 
   // Add the predicate to the asm string.
   if (interfaceOp.getPredicate().has_value() &&
@@ -169,33 +373,86 @@ void PtxBuilder::buildAndReplaceOp() {
   LLVM::InlineAsmOp inlineAsmOp = build();
   LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
 
-  // Case 1: no result
-  if (inlineAsmOp->getNumResults() == 0) {
+  // Case 0: no result at all → just erase wrapper op.
+  if (!hasResult) {
     rewriter.eraseOp(interfaceOp);
     return;
   }
 
-  // Case 2: single result, forward it directly
-  if (!needsPackUnpack(interfaceOp)) {
+  if (needsManualMapping) {
     rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
     return;
   }
 
-  // Case 3: multiple results were packed; unpack the struct.
-  assert(mlir::LLVM::LLVMStructType::classof(
-             inlineAsmOp.getResultTypes().front()) &&
-         "Expected result type to be LLVMStructType when unpacking multiple "
-         "results");
-  auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
-      inlineAsmOp.getResultTypes().front());
+  // Case 1: Simple path, return single scalar
+  if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+    if (inlineAsmOp->getNumResults() > 0) {
+      rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+    } else {
+      // RW-only case with no declared results: forward the RW value.
+      SmallVector<Value> results;
+      for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+        if (m == PTXRegisterMod::ReadWrite) {
+          results.push_back(v);
+          break;
+        }
+      rewriter.replaceOp(interfaceOp, results);
+    }
+    return;
+  }
+
+  const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+    return m == PTXRegisterMod::ReadWrite;
+  });
 
-  SmallVector<mlir::Value> unpacked;
+  // All multi-value paths produce a single struct result we need to unpack.
+  assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+         "expected struct return for multi-result inline asm");
   Value structVal = inlineAsmOp.getResult(0);
-  for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
-    Value unpackedValue = LLVM::ExtractValueOp::create(
-        rewriter, interfaceOp->getLoc(), structVal, idx);
-    unpacked.push_back(unpackedValue);
+  SmallVector<Value> unpacked =
+      extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+  // Case 2: only declared results (no RW): replace the op with all unpacked.
+  if (!hasRW && interfaceOp->getResults().size() > 0) {
+    rewriter.replaceOp(interfaceOp, unpacked);
+    return;
+  }
+
+  // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+  if (hasRW && interfaceOp->getResults().size() == 0) {
+    unsigned idx = 0;
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+      if (m != PTXRegisterMod::ReadWrite)
+        continue;
+      Value repl = unpacked[idx++];
+      v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+        Operation *owner = use.getOwner();
+        return owner != interfaceOp && owner != inlineAsmOp;
+      });
+    }
+    rewriter.eraseOp(interfaceOp);
+    return;
   }
 
-  rewriter.replaceOp(interfaceOp, unpacked);
+  // Case 4: mixed (RW + declared results).
+  {
+    // First rewrite RW operands in place.
+    unsigned idx = 0;
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+      if (m != PTXRegisterMod::ReadWrite)
+        continue;
+      Value repl = unpacked[idx++];
+      v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+        Operation *owner = use.getOwner();
+        return owner != interfaceOp && owner != inlineAsmOp;
+      });
+    }
+    // The remaining unpacked values correspond to the declared results.
+    SmallVector<Value> tail;
+    tail.reserve(unpacked.size() - idx);
+    for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+      tail.push_back(unpacked[i]);
+
+    rewriter.replaceOp(interfaceOp, tail);
+  }
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dbcc738b4419f..ae9134458095f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1123,7 +1123,7 @@ std::string NVVM::WgmmaMma...
[truncated]

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really cool, I don't have concerns scanning through it, hopefully someone can double check all the logic as well!

@joker-eph joker-eph requested a review from Copilot August 19, 2025 15:51
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR enhances the NVVM inline_ptx operation to better support read-write operands and simplifies the inline assembly generation. The key improvement is replacing the manual struct packing/unpacking with automatic handling and introducing readable register mapping with named placeholders.

Key changes:

  • Adds read-write (rw) operand support with automatic struct unpacking
  • Introduces named register placeholders ({$rN}, {$wN}, {$rwN}) that are automatically mapped to numbered registers
  • Simplifies the inline assembly generation by avoiding manual constraint specification

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
mlir/test/python/dialects/nvvm.py Adds Python test for new inline_ptx syntax with rw operands
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir Updates test cases to use new placeholder syntax and adds read-write operand tests
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp Implements getAsmValues for InlinePtxOp to handle read-write operands
mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp Major refactoring to support automatic placeholder rewriting and register constraint canonicalization
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp Updates to handle manual mapping flag from getAsmValues
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td Updates inline_ptx operation definition to include readWriteArgs
mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td Changes getAsmValues interface to return bool for manual mapping indication
mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h Updates PtxBuilder constructor and adds manual mapping support

@grypp
Copy link
Member Author

grypp commented Aug 20, 2025

This looks really cool, I don't have concerns scanning through it, hopefully someone can double check all the logic as well!

The only nontrivial part is the register numbering. For example, in the snippet below the mapping should be:
rw0=0, rw1=1, w0=2, w1=3, r0=4, r1=5.

 setp.ge.s32 p,   {$r0},  {$r1};
 selp.s32    {$rw0}, {$r0}, {$r1}, p;
 selp.s32    {$rw1}, {$r0}, {$r1}, p;
 selp.s32    {$w0},  {$r0}, {$r1}, p;
 selp.s32    {$w1},  {$r0}, {$r1}, p;

grypp added 5 commits August 21, 2025 09:23
Key Features
1. Multiple SSA returns – no struct packing/unpacking required.
2. Automatic struct unpacking – values are directly usable.
3. Readable register mapping
    * {$rwN} → read-write
    * {$rN} → read-only
    * {$wN} → write-only
4. Full read-write support (+ modifier).
5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints.
6. Predicate support: PTX @p predication support

IR Example:
```
%wo0, %wo1 = nvvm.inline_ptx """
 .reg .pred p;
 setp.ge.s32 p,   {$r0}, {$r1};
 selp.s32 {$rw0}, {$r0}, {$r1}, p;
 selp.s32 {$rw1}, {$r0}, {$r1}, p;
 selp.s32 {$w0},  {$r0}, {$r1}, p;
 selp.s32 {$w1},  {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32
```

After lowering
```
 %0 = llvm.inline_asm has_side_effects asm_dialect = att
 "{
                              .reg .pred p;\
                              setp.ge.s32 p, $4, $5;   \
                              selp.s32   $0, $4, $5, p;\
                              selp.s32   $1, $4, $5, p;\
                              selp.s32   $2, $4, $5, p;\
                              selp.s32   $3, $4, $5, p;\
   }"
   "=r,=r,=f,=f,f,f,0,1"
   %c500_i32, %c400_i32, %cst, %cst_0
   : (i32, i32, f32, f32)
   -> !llvm.struct<(i32, i32, f32, f32)>

 %1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
 %2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
 %3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
 %4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>

 // Unpacked result from nvvm.inline_ptx
 %5 = arith.addi %1, %2 : i32
 // read only
 %6 = arith.addf %cst, %cst_0 : f32
 // write only
 %7 = arith.addf %3, %4 : f32
```
@grypp grypp force-pushed the nvvm-ptx-readwrite branch from 341f7f2 to e9c0aa2 Compare August 21, 2025 09:26

// CHECK-LABEL: @inline_ptx_multi_rw_r(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
Copy link
Contributor

@durga4github durga4github Aug 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _r here in the function name is tricky. I think you meant it for return values.

So, let us name it: multi_return_rw or multi_rw_return

Copy link
Contributor

@durga4github durga4github left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM except for a few queries.

@grypp grypp merged commit 5c36fb3 into llvm:main Aug 21, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants