Skip to content

Commit 5c36fb3

Browse files
authored
[MLIR][NVVM] Improve inline_ptx, add readwrite support (#154358)
Key Features 1. Multiple SSA returns – no struct packing/unpacking required. 2. Automatic struct unpacking – values are directly usable. 3. Readable register mapping * {$rwN} → read-write * {$roN} → read-only * {$woN} → write-only 4. Full read-write support (+ modifier). 5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints. 6. Predicate support: PTX `@p` predication support IR Example: ``` %wo0, %wo1 = nvvm.inline_ptx """ .reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0}, {$r1}, p; selp.s32 {$rw1}, {$r0}, {$r1}, p; selp.s32 {$w0}, {$r0}, {$r1}, p; selp.s32 {$w1}, {$r0}, {$r1}, p; """ ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32 ``` After lowering ``` %0 = llvm.inline_asm has_side_effects asm_dialect = att "{ .reg .pred p;\ setp.ge.s32 p, $4, $5; \ selp.s32 $0, $4, $5, p;\ selp.s32 $1, $4, $5, p;\ selp.s32 $2, $4, $5, p;\ selp.s32 $3, $4, $5, p;\ }" "=r,=r,=f,=f,f,f,0,1" %c500_i32, %c400_i32, %cst, %cst_0 : (i32, i32, f32, f32) -> !llvm.struct<(i32, i32, f32, f32)> %1 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> %2 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> %3 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> %4 = llvm.extractvalue %0 : !llvm.struct<(i32, i32, f32, f32)> // Unpacked result from nvvm.inline_ptx %5 = arith.addi %1, %2 : i32 // read only %6 = arith.addf %cst, %cst_0 : f32 // write only %7 = arith.addf %3, %4 : f32 ```
1 parent 1b0b59a commit 5c36fb3

File tree

8 files changed

+481
-66
lines changed

8 files changed

+481
-66
lines changed

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ namespace NVVM {
2626
enum class PTXRegisterMod {
2727
/// Read register with no modifier
2828
Read = 0,
29-
/// Read register with '+' modifier
29+
/// Write register with '=' modifier
3030
Write = 2,
31-
/// Read register with '=' modifier.
32-
/// Note that, this is not natively supported by LLVM, but it is possible to
33-
/// set read and write for the same operand.
31+
/// ReadWrite register with '+' modifier.
32+
/// Note that, this is not natively supported by LLVM, the Interface does
33+
/// mapping
3434
ReadWrite = 1,
3535
};
3636

@@ -67,13 +67,19 @@ class PtxBuilder {
6767
SmallVector<Value> ptxOperands;
6868
// Register constraints (read, write, readwrite) and register data types
6969
std::string registerConstraints;
70-
70+
// Modifiers
71+
SmallVector<PTXRegisterMod> registerModifiers;
72+
// Has return value as write-only or read-write
7173
bool hasResult = false;
74+
// Indicates if the Op will handle the register mapping manually.
75+
bool needsManualRegisterMapping = false;
7276

7377
public:
7478
/// Single constructor that only initializes members.
75-
PtxBuilder(Operation *op, PatternRewriter &rewriter)
76-
: interfaceOp(op), rewriter(rewriter) {}
79+
PtxBuilder(Operation *op, PatternRewriter &rewriter,
80+
bool needsManualRegisterMapping = false)
81+
: interfaceOp(op), rewriter(rewriter),
82+
needsManualRegisterMapping(needsManualRegisterMapping) {}
7783

7884
/// Add an operand with the read/write input type.
7985
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
@@ -87,6 +93,16 @@ class PtxBuilder {
8793
void buildAndReplaceOp();
8894
};
8995

96+
/// Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the
97+
/// PTX code.
98+
void countPlaceholderNumbers(StringRef ptxCode,
99+
llvm::SmallDenseSet<unsigned> &seenRW,
100+
llvm::SmallDenseSet<unsigned> &seenW,
101+
llvm::SmallDenseSet<unsigned> &seenR,
102+
llvm::SmallVectorImpl<unsigned> &rwNums,
103+
llvm::SmallVectorImpl<unsigned> &wNums,
104+
llvm::SmallVectorImpl<unsigned> &rNums);
105+
90106
} // namespace NVVM
91107
} // namespace mlir
92108

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
124124
following this order:
125125
1) Adds results
126126
2) Adds operands
127-
3) Adds attributes
127+
3) Adds attributes
128+
Returns true if the OP is going to do register mapping itself
128129
}],
129-
/*retType=*/"void",
130+
/*retType=*/"bool",
130131
/*methodName=*/"getAsmValues",
131132
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
132-
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
133+
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
134+
),
133135
/*methodBody=*/"",
134136
/*defaultImpl=*/ [{
135137
mlir::Operation* op = $_op;
136138

137139
// Step 1. Add results
138-
for (auto val : op->getResults())
139-
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
140+
for (auto val : op->getResults())
141+
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
140142

141143
// Step 2. Add operands
142144
for (auto val : op->getOperands())
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
149151
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
150152
}
151153
}
154+
return false; // No manual mapping needed
152155
}]
153156
>
154157
];

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
315315
}];
316316

317317
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
318+
Variadic<AnyType>:$readWriteArgs,
318319
StrAttr:$ptxCode,
319320
PtxPredicate:$predicate);
320321

321322
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
322-
323-
let assemblyFormat = [{
324-
$ptxCode `(` $readOnlyArgs `)`
325-
(`,` `predicate` `=` $predicate^)? attr-dict
326-
`:` type(operands)
327-
(`->` type($writeOnlyArgs)^)?
323+
324+
let assemblyFormat = [{
325+
$ptxCode
326+
( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
327+
( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
328+
(`,` `predicate` `=` $predicate^)?
329+
attr-dict
330+
( `->` type($writeOnlyArgs)^ )?
328331
}];
329332

330333
let extraClassDefinition = [{
@@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
333336
return std::string(ptxInstStr.data());
334337
}
335338
}];
339+
340+
let extraClassDeclaration = [{
341+
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
342+
}];
336343
}
337344

338345
//===----------------------------------------------------------------------===//
@@ -3057,8 +3064,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
30573064
let hasVerifier = 1;
30583065

30593066
let extraClassDeclaration = [{
3060-
void getAsmValues(RewriterBase &rewriter,
3061-
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
3067+
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
30623068
}];
30633069
}
30643070

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ struct PtxLowering
5757

5858
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
5959
LDBG() << op.getPtx();
60-
PtxBuilder generator(op, rewriter);
6160

62-
op.getAsmValues(rewriter, asmValues);
61+
bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
62+
PtxBuilder generator(op, rewriter, needsManualMapping);
6363
for (auto &[asmValue, modifier] : asmValues) {
6464
LDBG() << asmValue << "\t Modifier : " << modifier;
6565
generator.insertValue(asmValue, modifier);

0 commit comments

Comments
 (0)