-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][NVVM] Improve inline_ptx, add readwrite support #154358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Guray Ozen (grypp) Changes: Key Features
IR Example: (omitted) After lowering: (omitted) Patch is 29.72 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/154358.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 3e3fcd7d1fb82..99b1d9709e3e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -26,9 +26,9 @@ namespace NVVM {
enum class PTXRegisterMod {
/// Read register with no modifier
Read = 0,
- /// Read register with '+' modifier
+ /// Read register with '=' modifier
Write = 2,
- /// Read register with '=' modifier.
+ /// Read register with '+' modifier.
/// Note that, this is not natively supported by LLVM, but it is possible to
/// set read and write for the same operand.
ReadWrite = 1,
@@ -67,13 +67,17 @@ class PtxBuilder {
SmallVector<Value> ptxOperands;
// Register constraints (read, write, readwrite) and register data types
std::string registerConstraints;
-
+ // Modifiers
+ SmallVector<PTXRegisterMod> registerModifiers;
bool hasResult = false;
+ bool needsManualMapping = false;
public:
/// Single constructor that only initializes members.
- PtxBuilder(Operation *op, PatternRewriter &rewriter)
- : interfaceOp(op), rewriter(rewriter) {}
+ PtxBuilder(Operation *op, PatternRewriter &rewriter,
+ bool needsManualMapping = false)
+ : interfaceOp(op), rewriter(rewriter),
+ needsManualMapping(needsManualMapping) {}
/// Add an operand with the read/write input type.
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index e98b94b5b3052..8e36749cdb361 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
following this order:
1) Adds results
2) Adds operands
- 3) Adds attributes
+ 3) Adds attributes
+ Returns true if it does the mapping manually
}],
- /*retType=*/"void",
+ /*retType=*/"bool",
/*methodName=*/"getAsmValues",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+ "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
+ ),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;
// Step 1. Add results
- for (auto val : op->getResults())
- asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto val : op->getResults())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
// Step 2. Add operands
for (auto val : op->getOperands())
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
}
}
+ return false; // No needs manual mapping
}]
>
];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f9cd58de8915f..786d42cf15666 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
}];
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
+ Variadic<AnyType>:$readWriteArgs,
StrAttr:$ptxCode,
PtxPredicate:$predicate);
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
-
- let assemblyFormat = [{
- $ptxCode `(` $readOnlyArgs `)`
- (`,` `predicate` `=` $predicate^)? attr-dict
- `:` type(operands)
- (`->` type($writeOnlyArgs)^)?
+
+ let assemblyFormat = [{
+ $ptxCode
+ ( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
+ ( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict
+ ( `->` type($writeOnlyArgs)^ )?
}];
let extraClassDefinition = [{
@@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
return std::string(ptxInstStr.data());
}
}];
+
+ let extraClassDeclaration = [{
+ bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
+ }];
}
//===----------------------------------------------------------------------===//
@@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
let hasVerifier = 1;
let extraClassDeclaration = [{
- void getAsmValues(RewriterBase &rewriter,
- llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+ bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
}];
}
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index e0144bff4d371..c67ec3642f121 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -57,9 +57,9 @@ struct PtxLowering
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
LDBG() << op.getPtx();
- PtxBuilder generator(op, rewriter);
- op.getAsmValues(rewriter, asmValues);
+ bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
+ PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
LDBG() << asmValue << "\t Modifier : " << modifier;
generator.insertValue(asmValue, modifier);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index e004d5f64733e..3cad9d3bd16e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,6 +12,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -59,12 +62,28 @@ static char getRegisterType(Value v) {
return getRegisterType(v.getType());
}
+/// Extract every elements of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+ Location loc, Value agg) {
+ auto structTy = cast<LLVM::LLVMStructType>(agg.getType());
+ SmallVector<Value> elems;
+ elems.reserve(structTy.getBody().size());
+ for (auto [i, t] : llvm::enumerate(structTy.getBody())) {
+ (void)t;
+ Value e = LLVM::ExtractValueOp::create(rewriter, loc, agg, i);
+ elems.push_back(e);
+ }
+ return elems;
+}
+
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
- LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+ LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+ registerModifiers.push_back(itype);
+
auto getModifier = [&]() -> const char * {
if (itype == PTXRegisterMod::ReadWrite) {
- assert(false && "Read-Write modifier is not supported. Try setting the "
- "same value as Write and Read separately.");
+ // "Read-Write modifier is not supported
+ // Interface canonicalize it later
return "+";
}
if (itype == PTXRegisterMod::Write) {
@@ -72,6 +91,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
return "";
};
+
auto addValue = [&](Value v) {
if (itype == PTXRegisterMod::Read) {
ptxOperands.push_back(v);
@@ -108,38 +128,222 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
/// Check if the operation needs to pack and unpack results.
-static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
- return interfaceOp->getNumResults() > 1;
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+ if (needsManualMapping)
+ return false;
+ const unsigned writeOnly = interfaceOp->getNumResults();
+ const unsigned readWrite =
+ llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+ return (writeOnly + readWrite) > 1;
}
/// Pack the result types of the interface operation.
/// If the operation has multiple results, it packs them into a struct
/// type. Otherwise, it returns the original result types.
-static SmallVector<Type> packResultTypes(MLIRContext *ctx,
- BasicPtxBuilderInterface interfaceOp) {
- TypeRange results = interfaceOp->getResultTypes();
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+ SmallVectorImpl<Value> &ptxOperands) {
+ MLIRContext *ctx = interfaceOp->getContext();
+ TypeRange resultRange = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+ // Single value path:
+ if (interfaceOp->getResults().size() == 1)
+ return SmallVector<Type>{resultRange.front()};
+
+ // No declared results: if there is an RW, forward its type.
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ return SmallVector<Type>{v.getType()};
+ }
+
+ SmallVector<Type> packed;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ packed.push_back(v.getType());
+ for (Type t : resultRange)
+ packed.push_back(t);
+
+ if (packed.empty())
+ return {};
+
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+ return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+/// - Turn every "+X" into "=X"
+/// - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+ SmallVector<llvm::StringRef> toks;
+ SmallVector<std::string> out;
+ SmallVector<unsigned> plusIdx;
+
+ csv.split(toks, ',');
+ out.reserve(toks.size() + 8);
+
+ for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+ StringRef t = toks[i].trim();
+ if (t.consume_front("+")) {
+ plusIdx.push_back(i);
+ out.push_back(("=" + t).str());
+ } else {
+ out.push_back(t.str());
+ }
+ }
+
+ // Append indices of original "+X" tokens.
+ for (unsigned idx : plusIdx)
+ out.push_back(std::to_string(idx));
+
+ // Join back to CSV.
+ std::string result;
+ result.reserve(csv.size() + plusIdx.size() * 2);
+ llvm::raw_string_ostream os(result);
+ for (size_t i = 0; i < out.size(); ++i) {
+ if (i)
+ os << ',';
+ os << out[i];
+ }
+ return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWrite{"rw"};
+constexpr llvm::StringLiteral kWriteOnly{"w"};
+constexpr llvm::StringLiteral kReadOnly{"r"};
+
+/// Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
+/// to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
+/// then `r*`. Duplicates are de-duplicated when assigning numbers.
+/// Unknown text is preserved verbatim.
+///
+/// Example Input:
+/// "{
+/// reg .pred p;
+/// setp.ge.s32 p, {$r0}, {$r1};"
+/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
+/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
+/// selp.s32 {$w0}, {$r0}, {$r1}, p;
+/// selp.s32 {$w1}, {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+/// reg .pred p;
+/// setp.ge.s32 p, $4, $5;"
+/// selp.s32 $0, $4, $5, p;
+/// selp.s32 $1, $4, $5, p;
+/// selp.s32 $2, $4, $5, p;
+/// selp.s32 $3, $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
+ // Match {$rwN}, {$wN}, {$rN}
+ llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", kReadWrite,
+ kWriteOnly, kReadOnly)
+ .str());
+
+ llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+ llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+ {
+ StringRef rest = asmText;
+ SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+ while (!rest.empty() && rx.match(rest, &m)) {
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+
+ if (m[1].equals_insensitive(kReadWrite)) {
+ if (seenRW.insert(num).second)
+ rwNums.push_back(num);
+ } else if (m[1].equals_insensitive(kWriteOnly)) {
+ if (seenW.insert(num).second)
+ wNums.push_back(num);
+ } else {
+ if (seenR.insert(num).second)
+ rNums.push_back(num);
+ }
+
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+ }
+
+ llvm::sort(rwNums);
+ llvm::sort(wNums);
+ llvm::sort(rNums);
+
+ llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+ unsigned nextId = 0;
+ for (unsigned n : rwNums)
+ rwMap[n] = nextId++;
+ for (unsigned n : wNums)
+ wMap[n] = nextId++;
+ for (unsigned n : rNums)
+ rMap[n] = nextId++;
+
+ std::string out;
+ out.reserve(asmText.size());
- if (!needsPackUnpack(interfaceOp))
- return llvm::to_vector<1>(results);
+ size_t prev = 0;
+ StringRef rest = asmText;
+ SmallVector<StringRef, 3> m;
+ while (!rest.empty() && rx.match(rest, &m)) {
+ // Compute absolute match bounds in the original buffer.
+ size_t absStart = (size_t)(m[0].data() - asmText.data());
+ size_t absEnd = absStart + m[0].size();
- SmallVector<mlir::Type> elems(results.begin(), results.end());
- auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
- return {sTy};
+ // Emit text before the match.
+ out.append(asmText.data() + prev, asmText.data() + absStart);
+
+ // Emit compact $K
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+ unsigned id = 0;
+ if (m[1].equals_insensitive(kReadWrite))
+ id = rwMap.lookup(num);
+ else if (m[1].equals_insensitive(kWriteOnly))
+ id = wMap.lookup(num);
+ else
+ id = rMap.lookup(num);
+
+ out.push_back('$');
+ out += std::to_string(id);
+
+ prev = absEnd;
+
+ // Advance search window.
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+
+ // Tail.
+ out.append(asmText.data() + prev, asmText.data() + asmText.size());
+ return out;
}
LLVM::InlineAsmOp PtxBuilder::build() {
- MLIRContext *ctx = interfaceOp->getContext();
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
+ SmallVector<Type> resultTypes = packResultTypes(
+ interfaceOp, needsManualMapping, registerModifiers, ptxOperands);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
registerConstraints[registerConstraints.size() - 1] == ',')
registerConstraints.pop_back();
+ registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
std::string ptxInstruction = interfaceOp.getPtx();
+ if (!needsManualMapping)
+ ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
@@ -169,33 +373,86 @@ void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- // Case 1: no result
- if (inlineAsmOp->getNumResults() == 0) {
+ // Case 0: no result at all → just erase wrapper op.
+ if (!hasResult) {
rewriter.eraseOp(interfaceOp);
return;
}
- // Case 2: single result, forward it directly
- if (!needsPackUnpack(interfaceOp)) {
+ if (needsManualMapping) {
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
return;
}
- // Case 3: multiple results were packed; unpack the struct.
- assert(mlir::LLVM::LLVMStructType::classof(
- inlineAsmOp.getResultTypes().front()) &&
- "Expected result type to be LLVMStructType when unpacking multiple "
- "results");
- auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
- inlineAsmOp.getResultTypes().front());
+ // Case 1: Simple path, return single scalar
+ if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+ if (inlineAsmOp->getNumResults() > 0) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ } else {
+ // RW-only case with no declared results: forward the RW value.
+ SmallVector<Value> results;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite) {
+ results.push_back(v);
+ break;
+ }
+ rewriter.replaceOp(interfaceOp, results);
+ }
+ return;
+ }
+
+ const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
- SmallVector<mlir::Value> unpacked;
+ // All multi-value paths produce a single struct result we need to unpack.
+ assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+ "expected struct return for multi-result inline asm");
Value structVal = inlineAsmOp.getResult(0);
- for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
- Value unpackedValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), structVal, idx);
- unpacked.push_back(unpackedValue);
+ SmallVector<Value> unpacked =
+ extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+ // Case 2: only declared results (no RW): replace the op with all unpacked.
+ if (!hasRW && interfaceOp->getResults().size() > 0) {
+ rewriter.replaceOp(interfaceOp, unpacked);
+ return;
+ }
+
+ // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+ if (hasRW && interfaceOp->getResults().size() == 0) {
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ rewriter.eraseOp(interfaceOp);
+ return;
}
- rewriter.replaceOp(interfaceOp, unpacked);
+ // Case 4: mixed (RW + declared results).
+ {
+ // First rewrite RW operands in place.
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ // The remaining unpacked values correspond to the declared results.
+ SmallVector<Value> tail;
+ tail.reserve(unpacked.size() - idx);
+ for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+ tail.push_back(unpacked[i]);
+
+ rewriter.replaceOp(interfaceOp, tail);
+ }
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dbcc738b4419f..ae9134458095f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1123,7 +1123,7 @@ std::string NVVM::WgmmaMma...
[truncated]
|
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks really cool, I don't have concerns scanning through it, hopefully someone can double check all the logic as well!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enhances the NVVM inline_ptx operation to better support read-write operands and simplifies the inline assembly generation. The key improvement is replacing the manual struct packing/unpacking with automatic handling and introducing readable register mapping with named placeholders.
Key changes:
- Adds read-write (rw) operand support with automatic struct unpacking
- Introduces named register placeholders ({$rN}, {$wN}, {$rwN}) that are automatically mapped to numbered registers
- Simplifies the inline assembly generation by avoiding manual constraint specification
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/python/dialects/nvvm.py | Adds Python test for new inline_ptx syntax with rw operands |
| mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | Updates test cases to use new placeholder syntax and adds read-write operand tests |
| mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | Implements getAsmValues for InlinePtxOp to handle read-write operands |
| mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp | Major refactoring to support automatic placeholder rewriting and register constraint canonicalization |
| mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | Updates to handle manual mapping flag from getAsmValues |
| mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | Updates inline_ptx operation definition to include readWriteArgs |
| mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td | Changes getAsmValues interface to return bool for manual mapping indication |
| mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h | Updates PtxBuilder constructor and adds manual mapping support |
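The only nontrivial part is the register numbering. For example, in the snippet below the mapping should be: {$rw0} → $0, {$rw1} → $1, {$w0} → $2, {$w1} → $3, {$r0} → $4, {$r1} → $5 (read-write operands first, then write-only, then read-only).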
Key Features
1. Multiple SSA returns – no struct packing/unpacking required.
2. Automatic struct unpacking – values are directly usable.
3. Readable register mapping
* {$rwN} → read-write
* {$rN} → read-only
* {$wN} → write-only
4. Full read-write support (+ modifier).
5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints.
6. Predicate support: PTX @p predication support
IR Example:
```
%wo0, %wo1 = nvvm.inline_ptx """
.reg .pred p;
setp.ge.s32 p, {$r0}, {$r1};
selp.s32 {$rw0}, {$r0}, {$r1}, p;
selp.s32 {$rw1}, {$r0}, {$r1}, p;
selp.s32 {$w0}, {$r0}, {$r1}, p;
selp.s32 {$w1}, {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32
```
After lowering
```
%0 = llvm.inline_asm has_side_effects asm_dialect = att
"{
.reg .pred p;\
setp.ge.s32 p, $4, $5; \
selp.s32 $0, $4, $5, p;\
selp.s32 $1, $4, $5, p;\
selp.s32 $2, $4, $5, p;\
selp.s32 $3, $4, $5, p;\
}"
"=r,=r,=f,=f,f,f,0,1"
%c500_i32, %c400_i32, %cst, %cst_0
: (i32, i32, f32, f32)
-> !llvm.struct<(i32, i32, f32, f32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
%2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
%3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
%4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>
// Unpacked result from nvvm.inline_ptx
%5 = arith.addi %1, %2 : i32
// read only
%6 = arith.addf %cst, %cst_0 : f32
// write only
%7 = arith.addf %3, %4 : f32
```
341f7f2 to
e9c0aa2
Compare
|
|
||
| // CHECK-LABEL: @inline_ptx_multi_rw_r( | ||
| // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32) | ||
| llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _r here in the function name is tricky. I think you meant it for return values.
So, let us name it: multi_return_rw or multi_rw_return
durga4github
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except for a few queries.
Key Features
Predicate support: PTX @p predication support
IR Example:
After lowering