Skip to content

Commit d4f97da

Browse files
authored
[mlir] Support emit fp16 and bf16 type to cpp (llvm#105803)
1 parent 4b7f07a commit d4f97da

File tree

7 files changed

+55
-19
lines changed

7 files changed

+55
-19
lines changed

mlir/docs/Dialects/emitc.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ The following convention is followed:
1212
operation, C++20 is required.
1313
* If `ssize_t` is used, then the code requires the POSIX header `sys/types.h`
1414
or any of the C++ headers in which the type is defined.
15+
* If `_Float16` is used, the code requires the support of C additional
16+
floating types.
17+
* If `__bf16` is used, the code requires a compiler that supports it, such as
18+
GCC or Clang.
1519
* Else the generated code is compatible with C99.
1620

1721
These restrictions are neither inherent to the EmitC dialect itself nor to the

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
116116
bool mlir::emitc::isSupportedFloatType(Type type) {
117117
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
118118
switch (floatType.getWidth()) {
119+
case 16: {
120+
if (llvm::isa<Float16Type, BFloat16Type>(type))
121+
return true;
122+
return false;
123+
}
119124
case 32:
120125
case 64:
121126
return true;

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,12 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
12581258
val.toString(strValue, 0, 0, false);
12591259
os << strValue;
12601260
switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1261+
case llvm::APFloatBase::S_IEEEhalf:
1262+
os << "f16";
1263+
break;
1264+
case llvm::APFloatBase::S_BFloat:
1265+
os << "bf16";
1266+
break;
12611267
case llvm::APFloatBase::S_IEEEsingle:
12621268
os << "f";
12631269
break;
@@ -1277,17 +1283,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
12771283

12781284
// Print floating point attributes.
12791285
if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1280-
if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
1281-
return emitError(loc,
1282-
"expected floating point attribute to be f32 or f64");
1286+
if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1287+
fAttr.getType())) {
1288+
return emitError(
1289+
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
12831290
}
12841291
printFloat(fAttr.getValue());
12851292
return success();
12861293
}
12871294
if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1288-
if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
1289-
return emitError(loc,
1290-
"expected floating point attribute to be f32 or f64");
1295+
if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1296+
dense.getElementType())) {
1297+
return emitError(
1298+
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
12911299
}
12921300
os << '{';
12931301
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
@@ -1640,6 +1648,14 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
16401648
}
16411649
if (auto fType = dyn_cast<FloatType>(type)) {
16421650
switch (fType.getWidth()) {
1651+
case 16: {
1652+
if (llvm::isa<Float16Type>(type))
1653+
return (os << "_Float16"), success();
1654+
else if (llvm::isa<BFloat16Type>(type))
1655+
return (os << "__bf16"), success();
1656+
else
1657+
return emitError(loc, "cannot emit float type ") << type;
1658+
}
16431659
case 32:
16441660
return (os << "float"), success();
16451661
case 64:

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,35 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
1515
}
1616

1717
// -----
18-
19-
func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
18+
func.func @arith_cast_f80(%arg0: f80) -> i32 {
2019
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
21-
%t = arith.fptosi %arg0 : bf16 to i32
20+
%t = arith.fptosi %arg0 : f80 to i32
2221
return %t: i32
2322
}
2423

2524
// -----
2625

27-
func.func @arith_cast_f16(%arg0: f16) -> i32 {
26+
func.func @arith_cast_f128(%arg0: f128) -> i32 {
2827
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
29-
%t = arith.fptosi %arg0 : f16 to i32
28+
%t = arith.fptosi %arg0 : f128 to i32
3029
return %t: i32
3130
}
3231

3332

3433
// -----
3534

36-
func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
35+
func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
3736
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
38-
%t = arith.sitofp %arg0 : i32 to bf16
39-
return %t: bf16
37+
%t = arith.sitofp %arg0 : i32 to f80
38+
return %t: f80
4039
}
4140

4241
// -----
4342

44-
func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
43+
func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
4544
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
46-
%t = arith.sitofp %arg0 : i32 to f16
47-
return %t: f16
45+
%t = arith.sitofp %arg0 : i32 to f128
46+
return %t: f128
4847
}
4948

5049
// -----

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
4646

4747
// -----
4848

49-
func.func @unsupported_type_f16() {
49+
func.func @unsupported_type_f128() {
5050
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
51-
%0 = memref.alloca() : memref<4xf16>
51+
%0 = memref.alloca() : memref<4xf128>
5252
return
5353
}
5454

mlir/test/Target/Cpp/const.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ func.func @emitc_constant() {
1111
%c6 = "emitc.constant"(){value = 2 : index} : () -> index
1212
%c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32
1313
%f64 = "emitc.constant"(){value = 4.0 : f64} : () -> f64
14+
%f16 = "emitc.constant"(){value = 2.0 : f16} : () -> f16
15+
%bf16 = "emitc.constant"(){value = 4.0 : bf16} : () -> bf16
1416
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
1517
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
1618
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
@@ -26,6 +28,8 @@ func.func @emitc_constant() {
2628
// CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2;
2729
// CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = 2.000000000e+00f;
2830
// CPP-DEFAULT-NEXT: double [[F64:[^ ]*]] = 4.00000000000000000e+00;
31+
// CPP-DEFAULT-NEXT: _Float16 [[F16:[^ ]*]] = 2.00000e+00f16;
32+
// CPP-DEFAULT-NEXT: __bf16 [[BF16:[^ ]*]] = 4.0000e+00bf16;
2933
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
3034
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
3135
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
@@ -40,6 +44,8 @@ func.func @emitc_constant() {
4044
// CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]];
4145
// CPP-DECLTOP-NEXT: float [[V7:[^ ]*]];
4246
// CPP-DECLTOP-NEXT: double [[F64:[^ ]*]];
47+
// CPP-DECLTOP-NEXT: _Float16 [[F16:[^ ]*]];
48+
// CPP-DECLTOP-NEXT: __bf16 [[BF16:[^ ]*]];
4349
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
4450
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
4551
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
@@ -52,6 +58,8 @@ func.func @emitc_constant() {
5258
// CPP-DECLTOP-NEXT: [[V6]] = 2;
5359
// CPP-DECLTOP-NEXT: [[V7]] = 2.000000000e+00f;
5460
// CPP-DECLTOP-NEXT: [[F64]] = 4.00000000000000000e+00;
61+
// CPP-DECLTOP-NEXT: [[F16]] = 2.00000e+00f16;
62+
// CPP-DECLTOP-NEXT: [[BF16]] = 4.0000e+00bf16;
5563
// CPP-DECLTOP-NEXT: [[V8]] = {0};
5664
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
5765
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};

mlir/test/Target/Cpp/types.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ func.func @ptr_types() {
2222
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
2323
// CHECK-NEXT: f<int64_t*>();
2424
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
25+
// CHECK-NEXT: f<_Float16*>();
26+
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
27+
// CHECK-NEXT: f<__bf16*>();
28+
emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
2529
// CHECK-NEXT: f<float*>();
2630
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
2731
// CHECK-NEXT: f<double*>();

0 commit comments

Comments
 (0)