Skip to content

Commit e98de2e

Browse files
authored
[MLIR][NVVM] Add support for f32x2 to f4x2 conversion (llvm#162273)
This change adds the `convert.f32x2.to.f4x2` op to the NVVM Dialect for converting a pair of f32 values to an f4x2 (`e2m1x2`) value. PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
1 parent de9013f commit e98de2e

File tree

4 files changed

+82
-0
lines changed

4 files changed

+82
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,6 +1655,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
16551655
}];
16561656
}
16571657

1658+
def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
1659+
let summary = "Convert a pair of float inputs to f4x2";
1660+
let description = [{
1661+
This Op converts each of the given float inputs to the specified fp4 type.
1662+
The result `dst` is returned as an i8 type where the converted values are
1663+
packed such that the value converted from `a` is stored in the upper 4 bits
1664+
of `dst` and the value converted from `b` is stored in the lower 4 bits of
1665+
`dst`.
1666+
The `relu` attribute, when set, lowers to the '.relu' variant of
1667+
the cvt instruction.
1668+
1669+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1670+
}];
1671+
1672+
let results = (outs I8:$dst);
1673+
let arguments = (ins F32:$a, F32:$b,
1674+
DefaultValuedAttr<BoolAttr, "false">:$relu,
1675+
TypeAttr:$dstTy);
1676+
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
1677+
let hasVerifier = 1;
1678+
1679+
let extraClassDeclaration = [{
1680+
static mlir::NVVM::IDArgPair
1681+
getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
1682+
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
1683+
}];
1684+
1685+
string llvmBuilder = [{
1686+
auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
1687+
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
1688+
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
1689+
}];
1690+
}
1691+
16581692
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
16591693
let summary = "Convert a pair of float inputs to f6x2";
16601694
let description = [{

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
309309
return success();
310310
}
311311

312+
LogicalResult ConvertF32x2ToF4x2Op::verify() {
313+
mlir::MLIRContext *ctx = getContext();
314+
315+
if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
316+
return emitOpError("Only ")
317+
<< mlir::Float4E2M1FNType::get(ctx)
318+
<< " type is supported for conversions from f32x2 to f4x2.";
319+
320+
return success();
321+
}
322+
312323
LogicalResult BulkStoreOp::verify() {
313324
if (getInitVal() != 0)
314325
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2047,6 +2058,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
20472058
}
20482059
}
20492060

2061+
NVVM::IDArgPair
2062+
ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
2063+
LLVM::ModuleTranslation &mt,
2064+
llvm::IRBuilderBase &builder) {
2065+
llvm::SmallVector<llvm::Value *> args;
2066+
args.push_back(mt.lookupValue(op.getA()));
2067+
args.push_back(mt.lookupValue(op.getB()));
2068+
2069+
bool hasRelu = op.getRelu();
2070+
2071+
llvm::Intrinsic::ID intId =
2072+
hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
2073+
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
2074+
2075+
return {intId, std::move(args)};
2076+
}
2077+
20502078
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
20512079
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
20522080
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
4+
llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
5+
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6+
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
7+
%res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
8+
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
9+
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
10+
%res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
11+
llvm.return
12+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
254254

255255
// -----
256256

257+
llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
258+
// expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
259+
%res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
260+
llvm.return
261+
}
262+
263+
// -----
264+
257265
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
258266
// expected-error @below {{cache eviction priority supported only for cache level L2}}
259267
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

0 commit comments

Comments
 (0)