Skip to content

Commit 5d300af

Browse files
authored
[MLIR][NVVM] Add support for multiple return values in inline_ptx (#153774)
This PR adds the ability for `nvvm.inline_ptx` to return multiple values, matching the expected semantics in PTX while respecting LLVM’s constraints. LLVM’s `inline_asm` op does not natively support multiple returns — instead, it requires packing results into an LLVM `struct` and then extracting them. This PR implements automatic packing/unpacking so that multiple return values can be expressed naturally in MLIR without extra user boilerplate. **Example** MLIR: ``` %r1, %r2 = nvvm.inline_ptx "{ .reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2, $3, p; selp.s32 $1, $2, $3, !p; }" (%a, %b) : i32, i32 -> i32, i32 %r3 = llvm.add %r1, %r2 : i32 ``` Lowered LLVM IR: ``` %1 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %a, %b : (i32, i32) -> !llvm.struct<(i32, i32)> %2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32)> %3 = llvm.extractvalue %1[1] : !llvm.struct<(i32, i32)> %4 = llvm.add %2, %3 : i32 ```
1 parent e6e874c commit 5d300af

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,32 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
107107
ss << getModifier() << getRegisterType(v) << ",";
108108
}
109109

110+
/// Check if the operation needs to pack and unpack results.
111+
static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
112+
return interfaceOp->getNumResults() > 1;
113+
}
114+
115+
/// Pack the result types of the interface operation.
116+
/// If the operation has multiple results, it packs them into a struct
117+
/// type. Otherwise, it returns the original result types.
118+
static SmallVector<Type> packResultTypes(MLIRContext *ctx,
119+
BasicPtxBuilderInterface interfaceOp) {
120+
TypeRange results = interfaceOp->getResultTypes();
121+
122+
if (!needsPackUnpack(interfaceOp))
123+
return llvm::to_vector<1>(results);
124+
125+
SmallVector<mlir::Type> elems(results.begin(), results.end());
126+
auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
127+
return {sTy};
128+
}
129+
110130
LLVM::InlineAsmOp PtxBuilder::build() {
131+
MLIRContext *ctx = interfaceOp->getContext();
111132
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
112133
LLVM::AsmDialect::AD_ATT);
113134

114-
auto resultTypes = interfaceOp->getResultTypes();
135+
SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
115136

116137
// Remove the last comma from the constraints string.
117138
if (!registerConstraints.empty() &&
@@ -136,7 +157,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
136157
rewriter, interfaceOp->getLoc(),
137158
/*result types=*/resultTypes,
138159
/*operands=*/ptxOperands,
139-
/*asm_string=*/llvm::StringRef(ptxInstruction),
160+
/*asm_string=*/ptxInstruction,
140161
/*constraints=*/registerConstraints.data(),
141162
/*has_side_effects=*/interfaceOp.hasSideEffect(),
142163
/*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -147,9 +168,34 @@ LLVM::InlineAsmOp PtxBuilder::build() {
147168
void PtxBuilder::buildAndReplaceOp() {
148169
LLVM::InlineAsmOp inlineAsmOp = build();
149170
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
150-
if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
151-
rewriter.replaceOp(interfaceOp, inlineAsmOp);
152-
} else {
171+
172+
// Case 1: no result
173+
if (inlineAsmOp->getNumResults() == 0) {
153174
rewriter.eraseOp(interfaceOp);
175+
return;
176+
}
177+
178+
// Case 2: single result, forward it directly
179+
if (!needsPackUnpack(interfaceOp)) {
180+
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
181+
return;
154182
}
183+
184+
// Case 3: multiple results were packed; unpack the struct.
185+
assert(mlir::LLVM::LLVMStructType::classof(
186+
inlineAsmOp.getResultTypes().front()) &&
187+
"Expected result type to be LLVMStructType when unpacking multiple "
188+
"results");
189+
auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
190+
inlineAsmOp.getResultTypes().front());
191+
192+
SmallVector<mlir::Value> unpacked;
193+
Value structVal = inlineAsmOp.getResult(0);
194+
for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
195+
Value unpackedValue = LLVM::ExtractValueOp::create(
196+
rewriter, interfaceOp->getLoc(), structVal, idx);
197+
unpacked.push_back(unpackedValue);
198+
}
199+
200+
rewriter.replaceOp(interfaceOp, unpacked);
155201
}

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,18 @@ llvm.func @ex2(%input : f32, %pred : i1) {
683683
llvm.return
684684
}
685685

686+
// CHECK-LABEL: @multi_return(
687+
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
688+
llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
689+
// CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
690+
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
691+
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
692+
// CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
693+
// CHECK: llvm.return %[[S4]] : i32
694+
%r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32
695+
%r3 = llvm.add %r1, %r2 : i32
696+
llvm.return %r3 : i32
697+
}
686698
// -----
687699

688700
// CHECK-LABEL: @nvvm_pmevent

0 commit comments

Comments
 (0)