Skip to content

Commit 16a02a4

Browse files
committed
[MLIR][NVVM] Improve inline_ptx, add readwrite support
Key Features

1. Multiple SSA returns – no struct packing/unpacking required.
2. Automatic struct unpacking – values are directly usable.
3. Readable register mapping:
   * {$rwN} → read-write
   * {$roN} → read-only
   * {$woN} → write-only
4. Full read-write support (+ modifier).
5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints.
6. Predicate support – PTX @p predication is supported.

IR Example:
```
%wo0, %wo1 = nvvm.inline_ptx """
  .reg .pred p;
  setp.ge.s32 p, {$ro0}, {$ro1};
  selp.s32 {$rw0}, {$ro0}, {$ro1}, p;
  selp.s32 {$rw1}, {$ro0}, {$ro1}, p;
  selp.s32 {$wo0}, {$ro0}, {$ro1}, p;
  selp.s32 {$wo1}, {$ro0}, {$ro1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32
```

After lowering:
```
%0 = llvm.inline_asm has_side_effects asm_dialect = att
  "{ .reg .pred p;\
  setp.ge.s32 p, $4, $5; \
  selp.s32 $0, $4, $5, p;\
  selp.s32 $1, $4, $5, p;\
  selp.s32 $2, $4, $5, p;\
  selp.s32 $3, $4, $5, p;\
  }"
  "=r,=r,=f,=f,f,f,0,1" %c500_i32, %c400_i32, %cst, %cst_0
  : (i32, i32, f32, f32) -> !llvm.struct<(i32, i32, f32, f32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
%2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
%3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
%4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>

// Unpacked result from nvvm.inline_ptx
%5 = arith.addi %1, %2 : i32
// read only
%6 = arith.addf %cst, %cst_0 : f32
// write only
%7 = arith.addf %3, %4 : f32
```
1 parent 5db67e1 commit 16a02a4

File tree

8 files changed

+435
-59
lines changed

8 files changed

+435
-59
lines changed

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ namespace NVVM {
2626
enum class PTXRegisterMod {
2727
/// Read register with no modifier
2828
Read = 0,
29-
/// Read register with '+' modifier
29+
/// Write register with '=' modifier
3030
Write = 2,
31-
/// Read register with '=' modifier.
31+
/// Read-write register with '+' modifier.
3232
/// Note that, this is not natively supported by LLVM, but it is possible to
3333
/// set read and write for the same operand.
3434
ReadWrite = 1,
@@ -67,13 +67,17 @@ class PtxBuilder {
6767
SmallVector<Value> ptxOperands;
6868
// Register constraints (read, write, readwrite) and register data types
6969
std::string registerConstraints;
70-
70+
// Modifiers
71+
SmallVector<PTXRegisterMod> registerModifiers;
7172
bool hasResult = false;
73+
bool needsManualMapping = false;
7274

7375
public:
7476
/// Single constructor that only initializes members.
75-
PtxBuilder(Operation *op, PatternRewriter &rewriter)
76-
: interfaceOp(op), rewriter(rewriter) {}
77+
PtxBuilder(Operation *op, PatternRewriter &rewriter,
78+
bool needsManualMapping = false)
79+
: interfaceOp(op), rewriter(rewriter),
80+
needsManualMapping(needsManualMapping) {}
7781

7882
/// Add an operand with the read/write input type.
7983
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
124124
following this order:
125125
1) Adds results
126126
2) Adds operands
127-
3) Adds attributes
127+
3) Adds attributes
128+
Returns true if the op performs the register mapping manually.
128129
}],
129-
/*retType=*/"void",
130+
/*retType=*/"bool",
130131
/*methodName=*/"getAsmValues",
131132
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
132-
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
133+
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
134+
),
133135
/*methodBody=*/"",
134136
/*defaultImpl=*/ [{
135137
mlir::Operation* op = $_op;
136138

137139
// Step 1. Add results
138-
for (auto val : op->getResults())
139-
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
140+
for (auto val : op->getResults())
141+
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
140142

141143
// Step 2. Add operands
142144
for (auto val : op->getOperands())
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
149151
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
150152
}
151153
}
154+
return false; // No manual mapping needed
152155
}]
153156
>
154157
];

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
315315
}];
316316

317317
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
318+
Variadic<AnyType>:$readWriteArgs,
318319
StrAttr:$ptxCode,
319320
PtxPredicate:$predicate);
320321

321322
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
322-
323-
let assemblyFormat = [{
324-
$ptxCode `(` $readOnlyArgs `)`
325-
(`,` `predicate` `=` $predicate^)? attr-dict
326-
`:` type(operands)
327-
(`->` type($writeOnlyArgs)^)?
323+
324+
let assemblyFormat = [{
325+
$ptxCode
326+
( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
327+
( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
328+
(`,` `predicate` `=` $predicate^)?
329+
attr-dict
330+
( `->` type($writeOnlyArgs)^ )?
328331
}];
329332

330333
let extraClassDefinition = [{
@@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
333336
return std::string(ptxInstStr.data());
334337
}
335338
}];
339+
340+
let extraClassDeclaration = [{
341+
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
342+
}];
336343
}
337344

338345
//===----------------------------------------------------------------------===//
@@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
30273034
let hasVerifier = 1;
30283035

30293036
let extraClassDeclaration = [{
3030-
void getAsmValues(RewriterBase &rewriter,
3031-
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
3037+
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
30323038
}];
30333039
}
30343040

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ struct PtxLowering
5757

5858
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
5959
LDBG() << op.getPtx();
60-
PtxBuilder generator(op, rewriter);
6160

62-
op.getAsmValues(rewriter, asmValues);
61+
bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
62+
PtxBuilder generator(op, rewriter, needsManualMapping);
6363
for (auto &[asmValue, modifier] : asmValues) {
6464
LDBG() << asmValue << "\t Modifier : " << modifier;
6565
generator.insertValue(asmValue, modifier);

0 commit comments

Comments
 (0)