Skip to content

Commit 6b07eb8

Browse files
Wolfram70mikolaj-pirog
authored andcommitted
[MLIR][NVVM] Add support for converting fp4/6/8 to fp16x2 (llvm#162439)
This change adds the following NVVM dialect Ops for converting fp4/6/8 to fp16x2: - `convert.f4x2.to.f16x2` - `convert.f6x2.to.f16x2` - `convert.f8x2.to.f16x2` - `convert.f8x2.to.bf16x2` Tests are added in `convert_fp4x2.mlir`, `convert_fp6x2.mlir`, and `convert_fp8x2.mlir`. PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
1 parent 9e8cf04 commit 6b07eb8

File tree

6 files changed

+287
-0
lines changed

6 files changed

+287
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,55 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
18721872
}];
18731873
}
18741874

1875+
// Base class for the NVVM Ops that convert a pair of narrow floating-point
// values (fp4/fp6/fp8) into a two-element f16 or bf16 vector.
//
//   srcType    - source family name ("F4", "F6" or "F8"); used in the Op
//                mnemonic, the generated documentation and the C++ Op name.
//   srcArgType - type constraint of the `$src` operand.
//   dstType    - destination element type name ("F16" or "BF16"); also the
//                name of the TableGen `Type` record used for the result.
//
// The `relu` attribute is only declared when `dstType` is "F16".
class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
: NVVM_Op<"convert." # !tolower(srcType) # "x2.to." # !tolower(dstType) # "x2"> {
  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to " # !tolower(dstType) # "x2";
  let description = [{
    This Op converts the given }] # !tolower(srcType) # [{ inputs in }] #
    !if(!eq(srcType, "F4"), "a packed i8", "an i8x2 vector") # [{ to }] #
    !tolower(dstType) # [{.

    The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
    }] #
    !if(!eq(dstType, "F16"),
    [{The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.}], "") # [{

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
  // Only the F16 destination variant declares the optional `relu` flag.
  let arguments = !if(!eq(dstType, "F16"),
    (ins srcArgType:$src,
         DefaultValuedAttr<BoolAttr, "false">:$relu,
         TypeAttr:$srcType),
    (ins srcArgType:$src,
         TypeAttr:$srcType));
  let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    static IDArgPair
    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                          llvm::IRBuilderBase &builder);
  }];

  // Translation to LLVM IR: the concrete Op's getIntrinsicIDAndArgs() picks
  // the intrinsic and its arguments; the call result becomes `$dst`.
  string llvmBuilder = [{
    auto [intId, args] =
    NVVM::Convert}] # srcType # [{x2To}] # dstType #
    [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
    $dst = createIntrinsicCall(builder, intId, args);
  }];
}

// f8x2 (vector<2xi8>) -> f16x2.
def NVVM_ConvertF8x2ToF16x2Op :
  NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16">;
// f8x2 (vector<2xi8>) -> bf16x2.
def NVVM_ConvertF8x2ToBF16x2Op :
  NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
// f6x2 (vector<2xi8>) -> f16x2.
def NVVM_ConvertF6x2ToF16x2Op :
  NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16">;
// f4x2 (single i8 operand) -> f16x2.
def NVVM_ConvertF4x2ToF16x2Op :
  NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
1923+
18751924
//===----------------------------------------------------------------------===//
18761925
// NVVM MMA Ops
18771926
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
320320
return success();
321321
}
322322

323+
/// Verifies that the `srcType` attribute is one of the two source types this
/// Op accepts (f8E4M3FN or f8E5M2); emits an Op error otherwise.
LogicalResult ConvertF8x2ToF16x2Op::verify() {
  mlir::Type srcTy = getSrcType();
  if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcTy))
    return success();

  // The accepted types are printed through the diagnostic stream so the
  // message shows their textual IR spelling.
  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float8E4M3FNType::get(context) << " and "
         << mlir::Float8E5M2Type::get(context)
         << " types are supported for conversions from f8x2 to f16x2.";
}
334+
335+
/// Verifies that the `srcType` attribute is f8E8M0FNU, the only source type
/// this Op accepts; emits an Op error otherwise.
LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  mlir::Type srcTy = getSrcType();
  if (llvm::isa<Float8E8M0FNUType>(srcTy))
    return success();

  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float8E8M0FNUType::get(context)
         << " type is supported for conversions from f8x2 to bf16x2.";
}
344+
345+
/// Verifies that the `srcType` attribute is one of the two source types this
/// Op accepts (f6E2M3FN or f6E3M2FN); emits an Op error otherwise.
LogicalResult ConvertF6x2ToF16x2Op::verify() {
  mlir::Type srcTy = getSrcType();
  if (llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(srcTy))
    return success();

  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float6E2M3FNType::get(context) << " and "
         << mlir::Float6E3M2FNType::get(context)
         << " types are supported for conversions from f6x2 to f16x2.";
}
356+
357+
/// Verifies that the `srcType` attribute is f4E2M1FN, the only source type
/// this Op accepts; emits an Op error otherwise.
LogicalResult ConvertF4x2ToF16x2Op::verify() {
  mlir::Type srcTy = getSrcType();
  if (llvm::isa<Float4E2M1FNType>(srcTy))
    return success();

  mlir::MLIRContext *context = getContext();
  return emitOpError("Only ")
         << mlir::Float4E2M1FNType::get(context)
         << " type is supported for conversions from f4x2 to f16x2.";
}
367+
323368
LogicalResult BulkStoreOp::verify() {
324369
if (getInitVal() != 0)
325370
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
21872232
}
21882233
}
21892234

2235+
/// Selects the NVVM intrinsic for an f8x2 -> f16x2 conversion from the Op's
/// `srcType` and `relu` attributes, and builds its single argument: the
/// source operand bitcast to i16.
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
  const bool relu = cvtOp.getRelu();
  mlir::Type srcTy = cvtOp.getSrcType();

  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (llvm::isa<Float8E4M3FNType>(srcTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
  } else if (llvm::isa<Float8E5M2Type>(srcTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
  } else {
    // The verifier only admits the two source types handled above.
    llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
  }

  // Reinterpret the translated source value as a single i16.
  llvm::Value *srcVal = mt.lookupValue(cvtOp.getSrc());
  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *asI16 = builder.CreateBitCast(srcVal, i16Ty);

  return {id, {asI16}};
}
2262+
2263+
/// Returns the `nvvm_ue8m0x2_to_bf16x2` intrinsic together with its single
/// argument: the source operand bitcast to i16.
NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);

  // Reinterpret the translated source value as a single i16.
  llvm::Value *srcVal = mt.lookupValue(cvtOp.getSrc());
  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *asI16 = builder.CreateBitCast(srcVal, i16Ty);

  return {llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2, {asI16}};
}
2274+
2275+
/// Selects the NVVM intrinsic for an f6x2 -> f16x2 conversion from the Op's
/// `srcType` and `relu` attributes, and builds its single argument: the
/// source operand bitcast to i16.
NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
  const bool relu = cvtOp.getRelu();
  mlir::Type srcTy = cvtOp.getSrcType();

  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (llvm::isa<Float6E2M3FNType>(srcTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
  } else if (llvm::isa<Float6E3M2FNType>(srcTy)) {
    id = relu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
              : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
  } else {
    // The verifier only admits the two source types handled above.
    llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
  }

  // Reinterpret the translated source value as a single i16.
  llvm::Value *srcVal = mt.lookupValue(cvtOp.getSrc());
  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *asI16 = builder.CreateBitCast(srcVal, i16Ty);

  return {id, {asI16}};
}
2302+
2303+
/// Selects the NVVM intrinsic for an f4x2 -> f16x2 conversion from the Op's
/// `relu` attribute, and builds its single argument: the i8 source operand
/// zero-extended to i16.
NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
  const bool relu = cvtOp.getRelu();
  mlir::Type srcTy = cvtOp.getSrcType();

  // The verifier only admits f4E2M1FN as the source type.
  if (!llvm::isa<Float4E2M1FNType>(srcTy))
    llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");

  llvm::Intrinsic::ID id = relu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
                                : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;

  // Widen the translated i8 source value to i16 with a zero-extension.
  llvm::Value *srcVal = mt.lookupValue(cvtOp.getSrc());
  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  llvm::Value *widened = builder.CreateZExt(srcVal, i16Ty);

  return {id, {widened}};
}
2326+
21902327
llvm::Intrinsic::ID
21912328
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
21922329
LLVM::ModuleTranslation &mt,

mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,14 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
1010
%res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
1111
llvm.return
1212
}
13+
14+
// CHECK-LABEL: @convert_f4x2_to_f16x2
15+
llvm.func @convert_f4x2_to_f16x2(%src : i8) {
16+
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
17+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]])
18+
%res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16>
19+
// CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
20+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]])
21+
%res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
22+
llvm.return
23+
}

mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
1919
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
2020
llvm.return
2121
}
22+
23+
// -----
24+
25+
// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
26+
llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
27+
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
28+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]])
29+
%res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
30+
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
31+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]])
32+
%res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
33+
llvm.return
34+
}
35+
36+
// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2
37+
llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
38+
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
39+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]])
40+
%res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
41+
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
42+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]])
43+
%res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
44+
llvm.return
45+
}

mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
100100
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
101101
llvm.return
102102
}
103+
104+
// -----
105+
106+
// CHECK-LABEL: @convert_f8x2_to_f16x2_e4m3
107+
llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) {
108+
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
109+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]])
110+
%res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
111+
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
112+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]])
113+
%res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
114+
llvm.return
115+
}
116+
117+
// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2
118+
llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) {
119+
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
120+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]])
121+
%res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2)-> vector<2xf16>
122+
// CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
123+
// CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]])
124+
%res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2)-> vector<2xf16>
125+
llvm.return
126+
}
127+
128+
// -----
129+
130+
// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0
131+
llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
132+
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
133+
// CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]])
134+
%res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
135+
llvm.return
136+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,38 @@ llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
262262

263263
// -----
264264

265+
llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
266+
// expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
267+
%res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
268+
llvm.return
269+
}
270+
271+
// -----
272+
273+
llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
274+
// expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
275+
%res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
276+
llvm.return
277+
}
278+
279+
// -----
280+
281+
llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
282+
// expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
283+
%res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
284+
llvm.return
285+
}
286+
287+
// -----
288+
289+
llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
290+
// expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
291+
%res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
292+
llvm.return
293+
}
294+
295+
// -----
296+
265297
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
266298
// expected-error @below {{cache eviction priority supported only for cache level L2}}
267299
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

0 commit comments

Comments
 (0)