Skip to content

Commit d1464bb

Browse files
committed
[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types
This change: - Adds the `cvt.float.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2` Ops to the NVVM dialect for the conversions to `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` types. - Renames the recently added `cvt.to.f6x2` Op to `cvt.float.to.f6x2` for consistency with the other conversion Ops. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent fa1fe11 commit d1464bb

File tree

5 files changed

+456
-12
lines changed

5 files changed

+456
-12
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 149 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
10791079
let assemblyFormat = "`<` $value `>`";
10801080
}
10811081

1082-
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1082+
def NVVM_CvtFloatToF6x2Op : NVVM_Op<"cvt.float.to.f6x2"> {
10831083
let summary = "Convert a pair of float inputs to f6x2";
10841084
let description = [{
10851085
This Op converts each of the given float inputs to the specified fp6 type.
@@ -1110,7 +1110,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11101110
}];
11111111

11121112
string llvmBuilder = [{
1113-
auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
1113+
auto intId = NVVM::CvtFloatToF6x2Op::getIntrinsicID($type, $relu);
11141114
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
11151115
if(op.getDst().getType().isInteger(16))
11161116
$dst = packedI16;
@@ -1120,6 +1120,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11201120
}];
11211121
}
11221122

1123+
// Destination fp8 formats selectable on the cvt.*.to.f8x2 Ops. Each case gives
// the C++ enumerator name, its integer value, and the assembly keyword.
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1124+
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1125+
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1126+
1127+
// The enum itself is emitted into ::mlir::NVVM; the specialized attribute
// class is not auto-generated here (genSpecializedAttr = 0) because the
// EnumAttr record below provides the attribute.
def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
1128+
[CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
1129+
let genSpecializedAttr = 0;
1130+
let cppNamespace = "::mlir::NVVM";
1131+
}
1132+
// Dialect attribute wrapping CVTFP8Type; its assembly form is the enum
// keyword enclosed in angle brackets.
def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
1133+
let assemblyFormat = "`<` $value `>`";
1134+
}
1135+
1136+
// Converts two f32 scalars to a packed pair of fp8 values. The legal
// attribute combinations are checked in CvtFloatToF8x2Op::verify().
def NVVM_CvtFloatToF8x2Op : NVVM_Op<"cvt.float.to.f8x2"> {
  let summary = "Convert a pair of float inputs to f8x2";
  let description = [{
    This Op converts each of the given f32 inputs to the specified fp8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    and the value converted from `b` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `rnd` and `sat` attributes specify the rounding and saturation modes
    respectively.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    The verifier only accepts the following combinations:
    - `e4m3` and `e5m2`: `rnd` must be RN and `sat` must be SATFINITE.
    - `ue8m0`: `rnd` must be RZ or RP, and `relu` must not be set.

    Both `rnd` and `sat` default to NONE, so `rnd` always has to be given
    explicitly, and `sat` has to be given explicitly for `e4m3` and `e5m2`.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    F32:$a,
    F32:$b,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
                                              NVVM::FPRoundingMode rnd,
                                              NVVM::SaturationMode sat,
                                              bool hasRelu);
  }];

  string llvmBuilder = [{
    auto intId = NVVM::CvtFloatToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
    // The intrinsic call yields the packed i16; it is bitcast only when the
    // Op's result type is the two-element i8 vector.
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];

  let hasVerifier = 1;
}
1182+
1183+
// Converts a two-element f16 vector to a packed pair of fp8 values. The
// accepted destination types are checked in CvtF16x2ToF8x2Op::verify().
def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
  let summary = "Convert an f16x2 input to f8x2";
  let description = [{
    This Op converts the given f16 inputs in an f16x2 vector to the specified
    f8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values from `a`
    are packed such that the value converted from the first element of `a`
    is stored in the upper 8 bits of `dst` and the value converted from the
    second element of `a` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    Only the `e4m3` and `e5m2` destination types are supported; `ue8m0` is
    rejected by the verifier.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    VectorOfLengthAndType<[2], [F16]>:$a,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
                                              bool hasRelu);
  }];

  string llvmBuilder = [{
    auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
    // The intrinsic call yields the packed i16; it is bitcast only when the
    // Op's result type is the two-element i8 vector.
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];

  let hasVerifier = 1;
}
1225+
1226+
// Converts a two-element bf16 vector to a packed pair of fp8 values. The
// accepted type and rounding modes are checked in CvtBF16x2ToF8x2Op::verify().
def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
  let summary = "Convert a bf16x2 input to f8x2";
  let description = [{
    This Op converts the given bf16 inputs in a bf16x2 vector to the specified
    f8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values from `a`
    are packed such that the value converted from the first element of `a`
    is stored in the upper 8 bits of `dst` and the value converted from the
    second element of `a` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `rnd` and `sat` attributes specify the rounding and saturation modes
    respectively.

    Only the `ue8m0` destination type is supported, and `rnd` must be RZ or
    RP; since `rnd` defaults to NONE it always has to be given explicitly.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    VectorOfLengthAndType<[2], [BF16]>:$a,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
                                              NVVM::SaturationMode sat);
  }];

  string llvmBuilder = [{
    auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
    // The intrinsic call yields the packed i16; it is bitcast only when the
    // Op's result type is the two-element i8 vector.
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];

  let hasVerifier = 1;
}
1269+
11231270
//===----------------------------------------------------------------------===//
11241271
// NVVM MMA Ops
11251272
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,66 @@ LogicalResult CvtFloatToTF32Op::verify() {
133133
return success();
134134
}
135135

136+
/// Rejects attribute combinations that are not accepted for the requested
/// fp8 destination type:
///  - e4m3 / e5m2 need RN rounding together with SATFINITE saturation;
///  - ue8m0 needs RZ or RP rounding and must not request relu.
LogicalResult CvtFloatToF8x2Op::verify() {
  const NVVM::FPRoundingMode rounding = getRnd();

  switch (getType()) {
  case CVTFP8Type::E4M3:
  case CVTFP8Type::E5M2: {
    const bool saturatesToFinite =
        getSat() == NVVM::SaturationMode::SATFINITE;
    if (rounding != NVVM::FPRoundingMode::RN || !saturatesToFinite)
      return emitOpError(
          "Only RN rounding mode and SATFINITE saturation mode "
          "are supported for conversions to .e4m3x2 or .e5m2x2 types from f32");
    return success();
  }
  case CVTFP8Type::UE8M0:
    // The rounding check comes first so it wins when both checks would fail.
    if (rounding != NVVM::FPRoundingMode::RZ &&
        rounding != NVVM::FPRoundingMode::RP)
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions to .ue8m0x2 type from f32");
    if (getRelu())
      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
    return success();
  }
  return success();
}
165+
166+
/// Accepts every destination type except ue8m0, which is rejected for
/// f16x2 sources.
LogicalResult CvtF16x2ToF8x2Op::verify() {
  if (getType() == CVTFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");
  return success();
}
177+
178+
/// Accepts only the ue8m0 destination type combined with RZ or RP rounding.
/// The type is checked before the rounding mode.
LogicalResult CvtBF16x2ToF8x2Op::verify() {
  if (getType() != CVTFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  const NVVM::FPRoundingMode rounding = getRnd();
  const bool roundingSupported = rounding == NVVM::FPRoundingMode::RZ ||
                                 rounding == NVVM::FPRoundingMode::RP;
  if (!roundingSupported)
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

  return success();
}
195+
136196
LogicalResult BulkStoreOp::verify() {
137197
if (getInitVal() != 0)
138198
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1290,17 +1350,81 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12901350
}
12911351
}
12921352

1293-
#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \
1353+
#define GET_FLOAT_TO_F6x2_ID(type, has_relu) \
12941354
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
12951355
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
12961356

1297-
llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1298-
bool hasRelu) {
1357+
llvm::Intrinsic::ID CvtFloatToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1358+
bool hasRelu) {
12991359
switch (type) {
13001360
case NVVM::CVTFP6Type::E2M3:
1301-
return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
1361+
return GET_FLOAT_TO_F6x2_ID(e2m3x2, hasRelu);
13021362
case NVVM::CVTFP6Type::E3M2:
1303-
return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
1363+
return GET_FLOAT_TO_F6x2_ID(e3m2x2, hasRelu);
1364+
}
1365+
}
1366+
1367+
#define GET_FLOAT_TO_F8X2_US_ID(rnd, has_satf) \
1368+
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1369+
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1370+
1371+
#define GET_FLOAT_TO_F8X2_S_ID(type, has_relu) \
1372+
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1373+
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
1374+
1375+
llvm::Intrinsic::ID CvtFloatToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
1376+
NVVM::FPRoundingMode rnd,
1377+
NVVM::SaturationMode sat,
1378+
bool hasRelu) {
1379+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1380+
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1381+
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1382+
1383+
switch (type) {
1384+
case NVVM::CVTFP8Type::E4M3:
1385+
return GET_FLOAT_TO_F8X2_S_ID(e4m3x2, hasRelu);
1386+
case NVVM::CVTFP8Type::E5M2:
1387+
return GET_FLOAT_TO_F8X2_S_ID(e5m2x2, hasRelu);
1388+
case NVVM::CVTFP8Type::UE8M0:
1389+
if (hasRoundingModeRZ)
1390+
return GET_FLOAT_TO_F8X2_US_ID(rz, hasSatFinite);
1391+
else if (hasRoundingModeRP)
1392+
return GET_FLOAT_TO_F8X2_US_ID(rp, hasSatFinite);
1393+
}
1394+
llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
1395+
}
1396+
1397+
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
1398+
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
1399+
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
1400+
1401+
llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
1402+
bool hasRelu) {
1403+
switch (type) {
1404+
case NVVM::CVTFP8Type::E4M3:
1405+
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1406+
case NVVM::CVTFP8Type::E5M2:
1407+
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1408+
default:
1409+
llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
1410+
}
1411+
}
1412+
1413+
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
1414+
has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
1415+
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
1416+
1417+
llvm::Intrinsic::ID
1418+
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1419+
NVVM::SaturationMode sat) {
1420+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1421+
switch (rnd) {
1422+
case NVVM::FPRoundingMode::RZ:
1423+
return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
1424+
case NVVM::FPRoundingMode::RP:
1425+
return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
1426+
default:
1427+
llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
13041428
}
13051429
}
13061430

mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@
33
// CHECK-LABEL: @convert_float_to_fp6x2_packed
44
// Result requested as i16: for each fp6 type only the intrinsic call
// returning i16 is matched, with no bitcast expected after it.
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
55
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6-
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
6+
%res1 = nvvm.cvt.float.to.f6x2 <e2m3> %srcA, %srcB : i16
77
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
8-
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
8+
%res2 = nvvm.cvt.float.to.f6x2 <e3m2> %srcA, %srcB : i16
99
llvm.return
1010
}
1111

1212
// CHECK-LABEL: @convert_float_to_fp6x2_vector
1313
// Result requested as vector<2xi8>: for each fp6 type the intrinsic call
// returning i16 must be followed immediately by a bitcast of that value to
// <2 x i8>.
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
1414
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1515
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
16-
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
16+
%res1 = nvvm.cvt.float.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
1717
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1818
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
19-
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
19+
%res2 = nvvm.cvt.float.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
2020
llvm.return
2121
}
22-

0 commit comments

Comments
 (0)