Skip to content

Commit 12ceea5

Browse files
committed
[MLIR][NVVM] Add support for f32x2 to f4x2 conversion
This change adds the `convert.f32x2.to.f4x2` op to the NVVM Dialect for converting a pair of f32 values to an f4x2 (`e2m1x2`) value. PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
1 parent f0a787b commit 12ceea5

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,47 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
16541654
}];
16551655
}
16561656

1657+
// Sole fp4 format case: C++ symbol `E2M1`, value 0, textual form `e2m1`.
def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
1658+
1659+
// 32-bit enum listing the fp4 formats a conversion may target; its C++
// namespace is set to ::mlir::NVVM below.
def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
1660+
[ConvertFP4E2M1]> {
1661+
let cppNamespace = "::mlir::NVVM";
1662+
}
1663+
// NVVM dialect attribute wrapping ConvertFP4Type (mnemonic
// `convert_fp4_type`). Its assembly form is the enum value in angle
// brackets, e.g. `<e2m1>`.
def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type,
1664+
"convert_fp4_type"> {
1665+
let assemblyFormat = "`<` $value `>`";
1666+
}
1667+
1668+
// NVVM op `convert.f32x2.to.f4x2`: converts two f32 operands (`a`, `b`) to
// the fp4 format selected by `type` and returns both packed in one i8 (`dst`).
// The `llvmBuilder` at the end of this def describes its LLVM IR translation.
def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
1669+
let summary = "Convert a pair of float inputs to f4x2";
1670+
let description = [{
1671+
This Op converts each of the given float inputs to the specified fp4 type.
1672+
The result `dst` is returned as an i8 type where the converted values are
1673+
packed such that the value converted from `a` is stored in the upper 4 bits
1674+
of `dst` and the value converted from `b` is stored in the lower 4 bits of
1675+
`dst`.
1676+
The `relu` attribute, when set, lowers to the '.relu' variant of
1677+
the cvt instruction.
1678+
1679+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1680+
}];
1681+
1682+
// Single i8 result that carries both converted 4-bit values.
let results = (outs I8:$dst);
1683+
// `type` selects the fp4 format; `relu` is a bool attribute that defaults
// to false when omitted.
let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
1684+
DefaultValuedAttr<BoolAttr, "false">:$relu);
1685+
// `relu` has no dedicated syntax; when present it appears in the trailing
// attribute dictionary (e.g. `{relu = true}`).
let assemblyFormat = "$type $a `,` $b attr-dict";
1686+
1687+
// Declares the C++ helper that picks the LLVM intrinsic for a given
// (type, relu) combination.
let extraClassDeclaration = [{
1688+
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
1689+
}];
1690+
1691+
// Translation: call the selected intrinsic on ($a, $b), then narrow its
// result (named packedI16 below) to i8 via CreateTruncOrBitCast to form $dst.
string llvmBuilder = [{
1692+
auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
1693+
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
1694+
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
1695+
}];
1696+
}
1697+
16571698
def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
16581699
def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
16591700

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,6 +1976,12 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
19761976
}
19771977
}
19781978

1979+
/// Returns the LLVM intrinsic ID used to lower the f32x2 -> f4x2 conversion.
///
/// \param type    Requested fp4 format. It does not affect the result in this
///                implementation. NOTE(review): presumably because E2M1 is the
///                only ConvertFP4Type case; revisit if more cases are added.
/// \param hasRelu Selects the `relu` flavor of the intrinsic when true.
/// \return `nvvm_ff_to_e2m1x2_rn_relu_satfinite` if `hasRelu` is set,
///         otherwise `nvvm_ff_to_e2m1x2_rn_satfinite`.
llvm::Intrinsic::ID
ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
  if (hasRelu)
    return llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite;
  return llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
}
1984+
19791985
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
19801986
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
19811987
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// Verifies the LLVM IR produced for nvvm.convert.f32x2.to.f4x2 with the
// <e2m1> type, both without and with the `relu` attribute.
2+
3+
// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
4+
llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
5+
// Without relu: expect the plain satfinite intrinsic returning i16, followed
// by a trunc to i8.
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6+
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
7+
%res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
8+
// With relu = true: expect the relu intrinsic variant, again followed by a
// trunc to i8.
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
9+
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
10+
%res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
11+
llvm.return
12+
}

0 commit comments

Comments
 (0)