Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
This Op converts each of the given float inputs to the specified fp6 type.
Expand All @@ -1096,6 +1096,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];

let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
CVTFP6TypeAttr:$type,
Expand All @@ -1110,7 +1111,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
}];

string llvmBuilder = [{
auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
Expand All @@ -1120,6 +1121,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
}];
}

// Enum cases for the fp8 destination formats accepted by the f8x2 conversion
// ops. Each case gives the C++ enumerator name, its integer value, and the
// keyword used in the textual assembly.
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
Comment on lines +1124 to +1126
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a type for each op won't really scale. I think we should use existing types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes sense. I did it this way here since the other existing type enums were prefixed with the corresponding Op name. But we should probably unify all of them and rename them with an NVVMFP prefix instead perhaps.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes Guray, I brought this up on the first fp6 cvt Op review itself (more from a re-use perspective, though)

With a unified enum (let's say, for all the FP types), we may need to update/tighten the verifiers of many Ops to error out on unsupported types. Please let us know your thoughts on this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need an enum embedded in the NVVM dialect? I'm asking whether we can just reuse the existing MLIR builtin types. At this point, I assume we have all the exotic types. What do you think?

Copy link
Contributor Author

@Wolfram70 Wolfram70 May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried using the builtin types (f8e4m3fn, f8e5m2, and f8e8m0fnu) for these Ops but ran into issues during lowering to LLVMIR and bitcasting the vector to a packed i16 for the intrinsics, since it looks like vectors of these types cannot be constructed/are not supported, and an assertion fails due to mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp:821:isCompatibleVectorType().


// 32-bit integer enum grouping the three fp8 format cases (E4M3, E5M2, UE8M0).
// The generated C++ enum lives in ::mlir::NVVM, which is how the dialect code
// refers to it (NVVM::CVTFP8Type).
def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
[CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
// NOTE(review): presumably disabled so the dialect-level EnumAttr that wraps
// this enum provides the attribute class instead -- confirm.
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
// NVVM dialect attribute wrapping CVTFP8Type, with the attribute mnemonic
// `cvt_fp8_type`.
def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
// Printed/parsed as the enum keyword enclosed in angle brackets.
let assemblyFormat = "`<` $value `>`";
}

// Op `cvt.f32x2.to.f8x2`: converts two f32 operands into a pair of fp8 values
// of the format selected by the `type` attribute.
def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
This Op converts each of the given float inputs to the specified fp8 type.
The result `dst` is represented as an i16 type or as a vector
of two i8 types.
If `dst` is returned as an i16 type, the converted values are packed such
that the value converted from `a` is stored in the upper 8 bits of `dst`
and the value converted from `b` is stored in the lower 8 bits of `dst`.
If `dst` is returned as a vector type, each converted value is stored as an
i8 element in the vector.
The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];

// CvtF32x2ToF8x2Op::verify() rejects unsupported combinations of type,
// rounding mode, saturation mode and relu.
let hasVerifier = 1;
// The result is either a single i16 or a 2-element vector of i8.
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
// `type` selects the fp8 format and `a`/`b` are the f32 inputs. `rnd`, `sat`
// and `relu` are optional and default to NONE, NONE and false respectively.
let arguments = (ins
CVTFP8TypeAttr:$type,
F32:$a,
F32:$b,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu);
// Syntax: type attribute, the two operands, the attribute dictionary, then
// only the result type is spelled out.
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

// Declares the static helper that maps (type, rnd, sat, relu) to the LLVM
// intrinsic used by the builder below.
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu);
}];

// LLVM IR translation: call the selected intrinsic on (a, b). Its value is
// used directly when the result type is i16, otherwise it is bitcast to a
// two-element i8 vector.
string llvmBuilder = [{
auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
else
$dst = builder.CreateBitCast(packedI16,
llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
}];
}

// Op `cvt.f16x2.to.f8x2`: converts a 2-element f16 vector into a pair of fp8
// values of the format selected by the `type` attribute.
def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
let summary = "Convert an f16x2 input to f8x2";
let description = [{
This Op converts the given f16 inputs in an f16x2 vector to the specified
f8 type.
The result `dst` is represented as an i16 type or as a vector
of two i8 types.
If `dst` is returned as an i16 type, the converted values from `a`
are packed such that the value converted from the first element of `a`
is stored in the upper 8 bits of `dst` and the value converted from the
second element of `a` is stored in the lower 8 bits of `dst`.
If `dst` is returned as a vector type, each converted value is stored as an
i8 element in the vector.
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];

// CvtF16x2ToF8x2Op::verify() rejects the ue8m0 type; only e4m3 and e5m2 are
// accepted for this op.
let hasVerifier = 1;
// The result is either a single i16 or a 2-element vector of i8.
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
// `type` selects the fp8 format, `a` is the 2-element f16 input vector, and
// `relu` is optional and defaults to false.
let arguments = (ins
CVTFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [F16]>:$a,
DefaultValuedAttr<BoolAttr, "false">:$relu);
// Syntax: type attribute, the operand, the attribute dictionary, then the
// operand type and the result type separated by `->`.
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

// Declares the static helper that maps (type, relu) to the LLVM intrinsic
// used by the builder below.
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
bool hasRelu);
}];

// LLVM IR translation: call the selected intrinsic on `a`. Its value is used
// directly when the result type is i16, otherwise it is bitcast to a
// two-element i8 vector.
string llvmBuilder = [{
auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
else
$dst = builder.CreateBitCast(packedI16,
llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
}];
}

// Op `cvt.bf16x2.to.f8x2`: converts a 2-element bf16 vector into a pair of fp8
// values. The verifier restricts `type` to ue8m0.
def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
let summary = "Convert a pair of bf16 inputs to f8x2";
let description = [{
This Op converts the given bf16 inputs in a bf16x2 vector to the specified
f8 type.
The result `dst` is represented as an i16 type or as a vector
of two i8 types.
If `dst` is returned as an i16 type, the converted values from `a`
are packed such that the value converted from the first element of `a`
is stored in the upper 8 bits of `dst` and the value converted from the
second element of `a` is stored in the lower 8 bits of `dst`.
If `dst` is returned as a vector type, each converted value is stored as an
i8 element in the vector.
The `rnd` and `sat` attributes specify the rounding and saturation modes
respectively.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];

// CvtBF16x2ToF8x2Op::verify() accepts only the ue8m0 type and only the RZ or
// RP rounding modes.
let hasVerifier = 1;
// The result is either a single i16 or a 2-element vector of i8.
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
// `type` selects the fp8 format and `a` is the 2-element bf16 input vector.
// `rnd` and `sat` are optional and both default to NONE.
let arguments = (ins
CVTFP8TypeAttr:$type,
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
// Syntax: type attribute, the operand, the attribute dictionary, then the
// operand type and the result type separated by `->`.
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

// Declares the static helper that maps (rnd, sat) to the LLVM intrinsic used
// by the builder below. The `type` attribute is not an input to the lookup.
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat);
}];

// LLVM IR translation: call the selected intrinsic on `a`. Its value is used
// directly when the result type is i16, otherwise it is bitcast to a
// two-element i8 vector.
string llvmBuilder = [{
auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
else
$dst = builder.CreateBitCast(packedI16,
llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
}];
}

//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
Expand Down
129 changes: 124 additions & 5 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,61 @@ LogicalResult CvtFloatToTF32Op::verify() {
return success();
}

/// Verifies that the rounding mode, saturation mode and relu flag are valid
/// for the requested fp8 destination type:
///   - e4m3 / e5m2: rounding must be RN and saturation must be SATFINITE.
///   - ue8m0: rounding must be RZ or RP, and relu must not be set.
/// Emits an op error and fails on the first violated rule.
LogicalResult CvtF32x2ToF8x2Op::verify() {
  const NVVM::FPRoundingMode roundingMode = getRnd();
  const CVTFP8Type targetType = getType();

  if (targetType == CVTFP8Type::E4M3 || targetType == CVTFP8Type::E5M2) {
    if (roundingMode != NVVM::FPRoundingMode::RN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (getSat() != NVVM::SaturationMode::SATFINITE)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    return success();
  }

  if (targetType == CVTFP8Type::UE8M0) {
    const bool roundsTowardZeroOrPositive =
        roundingMode == NVVM::FPRoundingMode::RZ ||
        roundingMode == NVVM::FPRoundingMode::RP;
    if (!roundsTowardZeroOrPositive)
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (getRelu())
      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
  }

  return success();
}

/// Verifies the destination type of an f16x2 -> f8x2 conversion: every type
/// except ue8m0 is accepted; ue8m0 produces an op error.
LogicalResult CvtF16x2ToF8x2Op::verify() {
  const bool isUnsupportedType = (CVTFP8Type::UE8M0 == getType());
  if (!isUnsupportedType)
    return success();

  return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                     "conversions from f16x2 to f8x2.");
}

/// Verifies a bf16x2 -> f8x2 conversion: the destination type must be ue8m0
/// and the rounding mode must be RZ or RP. The type is checked first, so a
/// wrong type is reported even when the rounding mode is also invalid.
LogicalResult CvtBF16x2ToF8x2Op::verify() {
  if (getType() != CVTFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  switch (getRnd()) {
  case NVVM::FPRoundingMode::RZ:
  case NVVM::FPRoundingMode::RP:
    return success();
  default:
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  }
}

LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
Expand Down Expand Up @@ -1290,17 +1345,81 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}

#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
bool hasRelu) {
llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
bool hasRelu) {
switch (type) {
case NVVM::CVTFP6Type::E2M3:
return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
case NVVM::CVTFP6Type::E3M2:
return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
}
}

// Selects the f32x2 -> ue8m0x2 intrinsic for rounding-mode token `rnd`,
// choosing the `_satfinite` variant when `has_satf` is true.
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

// Selects the f32x2 -> `type` intrinsic (e4m3x2 / e5m2x2, `_rn` rounding),
// choosing the `_relu` variant when `has_relu` is true.
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

/// Returns the LLVM intrinsic implementing an f32x2 -> f8x2 conversion.
///
/// \param type    Destination fp8 format.
/// \param rnd     Rounding mode; only consulted for UE8M0 (RZ or RP).
/// \param sat     Saturation mode; only consulted for UE8M0, where SATFINITE
///                selects the `_satfinite` intrinsic.
/// \param hasRelu Selects the `_relu` intrinsic; only consulted for E4M3 and
///                E5M2.
///
/// Any combination not handled below (e.g. UE8M0 with a rounding mode other
/// than RZ/RP) hits llvm_unreachable; CvtF32x2ToF8x2Op::verify() rejects the
/// UE8M0 rounding modes that would reach it.
llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
                                                     NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  switch (type) {
  case NVVM::CVTFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::CVTFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::CVTFP8Type::UE8M0: {
    // Rounding and saturation only matter for the ue8m0x2 intrinsics.
    const bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
    if (rnd == NVVM::FPRoundingMode::RZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    if (rnd == NVVM::FPRoundingMode::RP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
    break;
  }
  }
  llvm_unreachable("Invalid conversion in CvtF32x2ToF8x2Op");
}

// Selects the f16x2 -> `type` intrinsic (`_rn` rounding), choosing the `_relu`
// variant when `has_relu` is true.
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

/// Returns the LLVM intrinsic implementing an f16x2 -> f8x2 conversion to the
/// given fp8 `type`, using the relu variant when `hasRelu` is set. Only E4M3
/// and E5M2 are handled; any other type hits llvm_unreachable.
llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
                                                     bool hasRelu) {
  if (type == NVVM::CVTFP8Type::E4M3)
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);

  if (type == NVVM::CVTFP8Type::E5M2)
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);

  llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
}

// Selects the bf16x2 -> ue8m0x2 intrinsic for rounding-mode token `rnd`,
// choosing the `_satfinite` variant when `has_satf` is true.
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

/// Returns the LLVM intrinsic implementing a bf16x2 -> ue8m0x2 conversion for
/// rounding mode `rnd`; SATFINITE in `sat` selects the `_satfinite` variant.
/// Only RZ and RP are handled; any other rounding mode hits llvm_unreachable.
llvm::Intrinsic::ID
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                  NVVM::SaturationMode sat) {
  const bool saturateToFinite = (NVVM::SaturationMode::SATFINITE == sat);

  if (rnd == NVVM::FPRoundingMode::RZ)
    return GET_BF16X2_TO_F8X2_ID(rz, saturateToFinite);

  if (rnd == NVVM::FPRoundingMode::RP)
    return GET_BF16X2_TO_F8X2_ID(rp, saturateToFinite);

  llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
}

Expand Down
17 changes: 8 additions & 9 deletions mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: @convert_float_to_fp6x2_packed
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
llvm.return
}

// CHECK-LABEL: @convert_float_to_fp6x2_vector
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
llvm.return
}

Loading