Skip to content

Commit 1837100

Browse files
committed
[MLIR][NVVM] Update convert Ops to use builtin types
This change updates the `convert.f32x2.to.f6x2`, `convert.f32x2.to.f8x2`, `convert.f16x2.to.f8x2`, and `convert.bf16x2.to.f8x2` Ops to specify the destination type with a builtin type passed as a `TypeAttr` (`dstTy`), instead of the custom `ConvertFP6Type`/`ConvertFP8Type` enum attributes. The corresponding tests are updated to reflect the resulting change in the assembly format.
1 parent 9c3961f commit 1837100

File tree

5 files changed

+153
-144
lines changed

5 files changed

+153
-144
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 20 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2121
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
2222
include "mlir/Interfaces/InferIntRangeInterface.td"
2323
include "mlir/Dialect/LLVMIR/LLVMTypes.td"
24+
include "mlir/IR/CommonAttrConstraints.td"
2425

2526
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2627
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -1258,18 +1259,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
12581259
}];
12591260
}
12601261

1261-
def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
1262-
def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
1263-
1264-
def ConvertFP6Type : I32EnumAttr<"ConvertFP6Type", "NVVM ConvertFP6Type kind",
1265-
[ConvertFP6E2M3, ConvertFP6E3M2]> {
1266-
let genSpecializedAttr = 0;
1267-
let cppNamespace = "::mlir::NVVM";
1268-
}
1269-
def ConvertFP6TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP6Type, "convert_fp6_type"> {
1270-
let assemblyFormat = "`<` $value `>`";
1271-
}
1272-
12731262
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
12741263
let summary = "Convert a pair of float inputs to f6x2";
12751264
let description = [{
@@ -1290,19 +1279,20 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
12901279

12911280
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
12921281
let arguments = (ins
1293-
ConvertFP6TypeAttr:$type,
12941282
F32:$a,
12951283
F32:$b,
1296-
DefaultValuedAttr<BoolAttr, "false">:$relu);
1297-
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
1284+
DefaultValuedAttr<BoolAttr, "false">:$relu,
1285+
TypeAttr:$dstTy);
1286+
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
1287+
let hasVerifier = 1;
12981288

12991289
let extraClassDeclaration = [{
1300-
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP6Type,
1290+
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
13011291
bool hasRelu);
13021292
}];
13031293

13041294
string llvmBuilder = [{
1305-
auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($type, $relu);
1295+
auto intId = NVVM::ConvertF32x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
13061296
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
13071297
if(op.getDst().getType().isInteger(16))
13081298
$dst = packedI16;
@@ -1312,19 +1302,6 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
13121302
}];
13131303
}
13141304

1315-
def ConvertFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1316-
def ConvertFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1317-
def ConvertFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1318-
1319-
def ConvertFP8Type : I32EnumAttr<"ConvertFP8Type", "NVVM ConvertFP8Type kind",
1320-
[ConvertFP8E4M3, ConvertFP8E5M2, ConvertFP8UE8M0]> {
1321-
let genSpecializedAttr = 0;
1322-
let cppNamespace = "::mlir::NVVM";
1323-
}
1324-
def ConvertFP8TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP8Type, "convert_fp8_type"> {
1325-
let assemblyFormat = "`<` $value `>`";
1326-
}
1327-
13281305
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
13291306
let summary = "Convert a pair of float inputs to f8x2";
13301307
let description = [{
@@ -1346,23 +1323,23 @@ def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
13461323
let hasVerifier = 1;
13471324
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
13481325
let arguments = (ins
1349-
ConvertFP8TypeAttr:$type,
13501326
F32:$a,
13511327
F32:$b,
13521328
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
13531329
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
1354-
DefaultValuedAttr<BoolAttr, "false">:$relu);
1355-
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
1330+
DefaultValuedAttr<BoolAttr, "false">:$relu,
1331+
TypeAttr:$dstTy);
1332+
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
13561333

13571334
let extraClassDeclaration = [{
1358-
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
1335+
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
13591336
NVVM::FPRoundingMode rnd,
13601337
NVVM::SaturationMode sat,
13611338
bool hasRelu);
13621339
}];
13631340

13641341
string llvmBuilder = [{
1365-
auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
1342+
auto intId = NVVM::ConvertF32x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
13661343
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
13671344
if(op.getDst().getType().isInteger(16))
13681345
$dst = packedI16;
@@ -1394,18 +1371,18 @@ def NVVM_ConvertF16x2ToF8x2Op : NVVM_Op<"convert.f16x2.to.f8x2"> {
13941371
let hasVerifier = 1;
13951372
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
13961373
let arguments = (ins
1397-
ConvertFP8TypeAttr:$type,
13981374
VectorOfLengthAndType<[2], [F16]>:$a,
1399-
DefaultValuedAttr<BoolAttr, "false">:$relu);
1400-
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
1375+
DefaultValuedAttr<BoolAttr, "false">:$relu,
1376+
TypeAttr:$dstTy);
1377+
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
14011378

14021379
let extraClassDeclaration = [{
1403-
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP8Type to,
1380+
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
14041381
bool hasRelu);
14051382
}];
14061383

14071384
string llvmBuilder = [{
1408-
auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($type, $relu);
1385+
auto intId = NVVM::ConvertF16x2ToF8x2Op::getIntrinsicID($dstTy, $relu);
14091386
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
14101387
if(op.getDst().getType().isInteger(16))
14111388
$dst = packedI16;
@@ -1437,11 +1414,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
14371414
let hasVerifier = 1;
14381415
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
14391416
let arguments = (ins
1440-
ConvertFP8TypeAttr:$type,
14411417
VectorOfLengthAndType<[2], [BF16]>:$a,
14421418
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
1443-
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
1444-
let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
1419+
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
1420+
TypeAttr:$dstTy);
1421+
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
14451422

14461423
let extraClassDeclaration = [{
14471424
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 90 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,14 @@ LogicalResult ConvertFloatToTF32Op::verify() {
189189
return success();
190190
}
191191

192+
/// Verifies that the `dstTy` attribute names one of the two f6 formats this
/// op can convert to (f6E2M3FN or f6E3M2FN).
///
/// Returns success() for a supported destination type; otherwise emits an
/// op-level diagnostic and returns failure.
LogicalResult ConvertF32x2ToF6x2Op::verify() {
  // Report through emitOpError, like the other convert-op verifiers in this
  // file, so the diagnostic is prefixed with the op name. The message text
  // itself is unchanged.
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
    return emitOpError("Only f6E2M3FN and f6E3M2FN types are supported for "
                       "ConvertF32x2ToF6x2Op.");
  }
  return success();
}
199+
192200
LogicalResult ConvertF32x2ToF8x2Op::verify() {
193201
using RndMode = NVVM::FPRoundingMode;
194202
using SatMode = NVVM::SaturationMode;
@@ -200,41 +208,52 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
200208

201209
bool hasRelu = getRelu();
202210

203-
switch (getType()) {
204-
case ConvertFP8Type::E4M3:
205-
case ConvertFP8Type::E5M2:
206-
if (!isRoundingModeRN)
207-
return emitOpError("Only RN rounding mode is supported for conversions "
208-
"from f32x2 to .e4m3x2 or .e5m2x2 types");
209-
if (!isSatFinite)
210-
return emitOpError("Only SATFINITE saturation mode is supported for "
211-
"conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
212-
break;
213-
case ConvertFP8Type::UE8M0:
214-
if (!(isRoundingModeRZ || isRoundingModeRP))
215-
return emitOpError("Only RZ or RP rounding modes are supported for "
216-
"conversions from f32x2 to .ue8m0x2 type");
217-
if (hasRelu)
218-
return emitOpError("relu not supported for conversions to .ue8m0x2 type");
219-
break;
220-
}
221-
return success();
211+
return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
212+
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
213+
[&](mlir::Type) -> LogicalResult {
214+
if (!isRoundingModeRN) {
215+
return emitOpError(
216+
"Only RN rounding mode is supported for conversions from "
217+
"f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
218+
}
219+
if (!isSatFinite) {
220+
return emitOpError(
221+
"Only SATFINITE saturation mode is supported for conversions "
222+
"from f32x2 to f8E4M3FNx2 or f8E5M2x2 types");
223+
}
224+
return success();
225+
})
226+
.Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
227+
if (!(isRoundingModeRZ || isRoundingModeRP)) {
228+
return emitOpError("Only RZ or RP rounding modes are supported for "
229+
"conversions from f32x2 to f8E8M0FNUx2 type");
230+
}
231+
if (hasRelu) {
232+
return emitOpError(
233+
"relu not supported for conversions to f8E8M0FNUx2 type");
234+
}
235+
return success();
236+
})
237+
.Default([this](mlir::Type) {
238+
return emitOpError("Only f8e4m3fn, f8e5m2, and f8e8m0fnu types are "
239+
"supported for conversions from f32x2 to f8x2");
240+
});
222241
}
223242

224243
LogicalResult ConvertF16x2ToF8x2Op::verify() {
225-
if (getType() == ConvertFP8Type::UE8M0)
226-
return emitOpError("Only .e4m3 or .e5m2 types are supported for "
244+
if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
245+
return emitOpError("Only f8E4M3FN or f8E5M2 types are supported for "
227246
"conversions from f16x2 to f8x2.");
228-
247+
}
229248
return success();
230249
}
231250

232251
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
233252
using RndMode = NVVM::FPRoundingMode;
234253

235-
if (getType() != ConvertFP8Type::UE8M0)
236-
return emitOpError(
237-
"Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
254+
if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
255+
return emitOpError("Only f8E8M0FNU type is supported for conversions from "
256+
"bf16x2 to f8x2.");
238257

239258
auto rnd = getRnd();
240259
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1714,15 +1733,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
17141733
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
17151734
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
17161735

1717-
llvm::Intrinsic::ID
1718-
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
1719-
switch (type) {
1720-
case NVVM::ConvertFP6Type::E2M3:
1721-
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
1722-
case NVVM::ConvertFP6Type::E3M2:
1723-
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1724-
}
1725-
llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
1736+
llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
1737+
bool hasRelu) {
1738+
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
1739+
.Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
1740+
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
1741+
})
1742+
.Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
1743+
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1744+
})
1745+
.Default([](mlir::Type) {
1746+
llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
1747+
return llvm::Intrinsic::not_intrinsic;
1748+
});
17261749
}
17271750

17281751
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
@@ -1734,41 +1757,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
17341757
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
17351758

17361759
llvm::Intrinsic::ID
1737-
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
1738-
NVVM::FPRoundingMode rnd,
1760+
ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
17391761
NVVM::SaturationMode sat, bool hasRelu) {
17401762
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
17411763
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
17421764
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
17431765

1744-
switch (type) {
1745-
case NVVM::ConvertFP8Type::E4M3:
1746-
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1747-
case NVVM::ConvertFP8Type::E5M2:
1748-
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1749-
case NVVM::ConvertFP8Type::UE8M0:
1750-
if (hasRoundingModeRZ)
1751-
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1752-
else if (hasRoundingModeRP)
1753-
return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1754-
}
1755-
llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
1766+
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
1767+
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
1768+
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1769+
})
1770+
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
1771+
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1772+
})
1773+
.Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
1774+
if (hasRoundingModeRZ)
1775+
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1776+
else if (hasRoundingModeRP)
1777+
return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1778+
1779+
llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
1780+
})
1781+
.Default([](mlir::Type) {
1782+
llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
1783+
return llvm::Intrinsic::not_intrinsic;
1784+
});
17561785
}
17571786

17581787
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
17591788
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
17601789
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
17611790

1762-
llvm::Intrinsic::ID
1763-
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
1764-
switch (type) {
1765-
case NVVM::ConvertFP8Type::E4M3:
1766-
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1767-
case NVVM::ConvertFP8Type::E5M2:
1768-
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1769-
default:
1770-
llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
1771-
}
1791+
llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
1792+
bool hasRelu) {
1793+
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
1794+
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
1795+
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1796+
})
1797+
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
1798+
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1799+
})
1800+
.Default([](mlir::Type) {
1801+
llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
1802+
return llvm::Intrinsic::not_intrinsic;
1803+
});
17721804
}
17731805

17741806
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \

mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
44
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
55
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6-
%res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
6+
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
77
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
8-
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
8+
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
99
llvm.return
1010
}
1111

1212
// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector
1313
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
1414
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1515
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
16-
%res1 = nvvm.convert.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
16+
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E2M3FN)
1717
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1818
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
19-
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
19+
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
2020
llvm.return
2121
}

0 commit comments

Comments
 (0)