Skip to content

Commit c945022

Browse files
authored
[MLIR][NVVM] Support packed registers in inline_ptx (#154904)
Add support for packed registers with vectors. Example: ``` %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};" ro(%src, %mask, %zero : vector<4xi8>, i32, i32) -> i32 ``` Here, `vector<4xi8>` is lowered to an `i32` register (i.e., an `r` in PTX).
1 parent 11f4be0 commit c945022

File tree

4 files changed

+138
-30
lines changed

4 files changed

+138
-30
lines changed

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/BuiltinAttributes.h"
1919
#include "mlir/IR/PatternMatch.h"
2020
#include "mlir/IR/Value.h"
21+
#include "llvm/Support/LogicalResult.h"
2122

2223
namespace mlir {
2324
namespace NVVM {
@@ -82,7 +83,8 @@ class PtxBuilder {
8283
needsManualRegisterMapping(needsManualRegisterMapping) {}
8384

8485
/// Add an operand with the read/write input type.
85-
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
86+
LogicalResult insertValue(Value v,
87+
PTXRegisterMod itype = PTXRegisterMod::Read);
8688

8789
/// Builds the inline assembly Op and returns it. The `insertValue` needs to
8890
/// be called to pass operands before building the PTX.

mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Pass/Pass.h"
2727
#include "mlir/Support/LLVM.h"
2828
#include "llvm/Support/DebugLog.h"
29+
#include "llvm/Support/LogicalResult.h"
2930
#include "llvm/Support/raw_ostream.h"
3031

3132
#define DEBUG_TYPE "nvvm-to-llvm"
@@ -62,7 +63,8 @@ struct PtxLowering
6263
PtxBuilder generator(op, rewriter, needsManualMapping);
6364
for (auto &[asmValue, modifier] : asmValues) {
6465
LDBG() << asmValue << "\t Modifier : " << modifier;
65-
generator.insertValue(asmValue, modifier);
66+
if (failed(generator.insertValue(asmValue, modifier)))
67+
return failure();
6668
}
6769

6870
generator.buildAndReplaceOp();

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
15+
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/Diagnostics.h"
17+
#include "mlir/IR/Location.h"
18+
#include "mlir/IR/MLIRContext.h"
1519

20+
#include "mlir/Support/LLVM.h"
1621
#include "llvm/ADT/StringExtras.h"
22+
#include "llvm/ADT/TypeSwitch.h"
1723
#include "llvm/Support/DebugLog.h"
1824
#include "llvm/Support/FormatVariadic.h"
25+
#include "llvm/Support/LogicalResult.h"
1926
#include "llvm/Support/Regex.h"
2027

2128
#define DEBUG_TYPE "ptx-builder"
@@ -31,35 +38,88 @@ using namespace NVVM;
3138

3239
static constexpr int64_t kSharedMemorySpace = 3;
3340

34-
static char getRegisterType(Type type) {
35-
if (type.isInteger(1))
36-
return 'b';
37-
if (type.isInteger(16))
38-
return 'h';
39-
if (type.isInteger(32))
40-
return 'r';
41-
if (type.isInteger(64))
42-
return 'l';
43-
if (type.isF32())
44-
return 'f';
45-
if (type.isF64())
46-
return 'd';
47-
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
48-
// Shared address spaces is addressed with 32-bit pointers.
49-
if (ptr.getAddressSpace() == kSharedMemorySpace) {
41+
static FailureOr<char> getRegisterType(Type type, Location loc) {
42+
MLIRContext *ctx = type.getContext();
43+
auto i16 = IntegerType::get(ctx, 16);
44+
auto i32 = IntegerType::get(ctx, 32);
45+
auto f32 = Float32Type::get(ctx);
46+
47+
auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
48+
if (type.isInteger(1))
49+
return 'b';
50+
if (type.isInteger(16))
51+
return 'h';
52+
if (type.isInteger(32))
5053
return 'r';
54+
if (type.isInteger(64))
55+
return 'l';
56+
if (type.isF32())
57+
return 'f';
58+
if (type.isF64())
59+
return 'd';
60+
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
61+
// Shared address spaces is addressed with 32-bit pointers.
62+
if (ptr.getAddressSpace() == kSharedMemorySpace) {
63+
return 'r';
64+
}
65+
return 'l';
66+
}
67+
// register type for struct is not supported.
68+
mlir::emitError(
69+
loc, "The register type could not be deduced from MLIR type. The ")
70+
<< type
71+
<< " is not supported. Supported types are:"
72+
"i1, i16, i32, i64, f32, f64,"
73+
"pointers.\nPlease use llvm.bitcast if you have different type. "
74+
"\nSee the constraints from here: "
75+
"https://docs.nvidia.com/cuda/inline-ptx-assembly/"
76+
"index.html#constraints";
77+
return failure();
78+
};
79+
80+
// Packed registers
81+
if (auto v = dyn_cast<VectorType>(type)) {
82+
assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
83+
84+
int64_t lanes = v.getNumElements();
85+
Type elem = v.getElementType();
86+
87+
// Case 1. Single vector
88+
if (lanes <= 1)
89+
return getRegisterTypeForScalar(elem);
90+
91+
// Case 2. Packed registers
92+
Type widened = elem;
93+
switch (lanes) {
94+
95+
case 2:
96+
if (elem.isF16() || elem.isBF16()) // vector<2xf16>
97+
widened = f32;
98+
else if (elem.isFloat(8)) // vector<2xf8>
99+
widened = i16;
100+
break;
101+
case 4:
102+
if (elem.isInteger(8)) // vector<i8x4>
103+
widened = i32;
104+
else if (elem.isFloat(8)) // vector<f8x4>
105+
widened = f32;
106+
else if (elem.isFloat(4)) // vector<f4x4>
107+
widened = i16;
108+
break;
109+
// Other packing is not supported
110+
default:
111+
break;
51112
}
52-
return 'l';
113+
return getRegisterTypeForScalar(widened);
53114
}
54-
// register type for struct is not supported.
55-
llvm_unreachable("The register type could not deduced from MLIR type");
56-
return '?';
115+
116+
return getRegisterTypeForScalar(type);
57117
}
58118

59-
static char getRegisterType(Value v) {
119+
static FailureOr<char> getRegisterType(Value v, Location loc) {
60120
if (v.getDefiningOp<LLVM::ConstantOp>())
61121
return 'n';
62-
return getRegisterType(v.getType());
122+
return getRegisterType(v.getType(), loc);
63123
}
64124

65125
/// Extract every element of a struct value.
@@ -75,10 +135,11 @@ static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
75135
return elems;
76136
}
77137

78-
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
138+
LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
79139
LDBG() << v << "\t Modifier : " << itype << "\n";
80140
registerModifiers.push_back(itype);
81141

142+
Location loc = interfaceOp->getLoc();
82143
auto getModifier = [&]() -> const char * {
83144
switch (itype) {
84145
case PTXRegisterMod::Read:
@@ -111,21 +172,29 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
111172
}
112173
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
113174
if (itype != PTXRegisterMod::Write) {
114-
Value extractValue = LLVM::ExtractValueOp::create(
115-
rewriter, interfaceOp->getLoc(), v, idx);
175+
Value extractValue =
176+
LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
116177
addValue(extractValue);
117178
}
118179
if (itype == PTXRegisterMod::ReadWrite) {
119180
ss << idx << ",";
120181
} else {
121-
ss << getModifier() << getRegisterType(t) << ",";
182+
FailureOr<char> regType = getRegisterType(t, loc);
183+
if (failed(regType))
184+
return rewriter.notifyMatchFailure(loc,
185+
"failed to get register type");
186+
ss << getModifier() << regType.value() << ",";
122187
}
123188
}
124-
return;
189+
return success();
125190
}
126191
// Handle Scalars
127192
addValue(v);
128-
ss << getModifier() << getRegisterType(v) << ",";
193+
FailureOr<char> regType = getRegisterType(v, loc);
194+
if (failed(regType))
195+
return rewriter.notifyMatchFailure(loc, "failed to get register type");
196+
ss << getModifier() << regType.value() << ",";
197+
return success();
129198
}
130199

131200
/// Check if the operation needs to pack and unpack results.

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,38 @@ llvm.func @nvvm_pmevent() {
745745
nvvm.pmevent id = 4
746746
llvm.return
747747
}
748+
749+
// -----
750+
751+
llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>, %mask : i32, %zero: i32) {
752+
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,r,r" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
753+
%wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
754+
ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
755+
-> i32
756+
llvm.return
757+
}
758+
759+
llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32) {
760+
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
761+
%wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};"
762+
ro(%a, %b : f32, f32)
763+
-> vector<2xbf16>
764+
llvm.return
765+
}
766+
767+
llvm.func @inline_ptx_cvt_rn_e4m3x2_f16x2(%a : i16) {
768+
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1", "=f,h" %{{.*}} : (i16) -> vector<2xbf16>
769+
%wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0}"
770+
ro(%a : i16)
771+
-> vector<2xbf16>
772+
llvm.return
773+
}
774+
775+
llvm.func @cvt_i8_bf16(%a : i8) {
776+
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $0;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $1, r;\0A\09", "=h,h" %{{.*}} : (i16) -> i16
777+
%za = llvm.zext %a : i8 to i16
778+
%wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$w0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$r0}, r;\n\t"
779+
ro(%za : i16)
780+
-> i16
781+
llvm.return
782+
}

0 commit comments

Comments
 (0)