Skip to content

Commit 263896a

Browse files
authored
Merge pull request #366 from Xilinx/bump_to_d4f97da1
[AutoBump] Merge with fixes of d4f97da (Aug 27) (13)
2 parents 7b9c9c0 + 9a3518c commit 263896a

File tree

8 files changed

+65
-69
lines changed

8 files changed

+65
-69
lines changed

mlir/docs/Dialects/emitc.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ The following convention is followed:
1414
or any of the C++ headers in which the type is defined.
1515
* If `emitc.array` with a dimension of size zero, then the code
1616
requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html).
17+
* If `_Float16` is used, the code requires the support of C additional
18+
floating types.
19+
* If `__bf16` is used, the code requires a compiler that supports it, such as
20+
GCC or Clang.
1721
* Else the generated code is compatible with C99.
1822

1923
These restrictions are neither inherent to the EmitC dialect itself nor to the

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,21 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
116116
}
117117

118118
bool mlir::emitc::isSupportedFloatType(Type type) {
119-
return isa<Float32Type, Float64Type>(type);
119+
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
120+
switch (floatType.getWidth()) {
121+
case 16: {
122+
if (llvm::isa<Float16Type, BFloat16Type>(type))
123+
return true;
124+
return false;
125+
}
126+
case 32:
127+
case 64:
128+
return isa<Float32Type, Float64Type>(type);
129+
default:
130+
return false;
131+
}
132+
}
133+
return false;
120134
}
121135

122136
bool mlir::emitc::isFloatOrOpaqueType(Type type) {

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,12 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
13731373
val.toString(strValue, 0, 0, false);
13741374
os << strValue;
13751375
switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1376+
case llvm::APFloatBase::S_IEEEhalf:
1377+
os << "f16";
1378+
break;
1379+
case llvm::APFloatBase::S_BFloat:
1380+
os << "bf16";
1381+
break;
13761382
case llvm::APFloatBase::S_IEEEsingle:
13771383
os << "f";
13781384
break;
@@ -1392,17 +1398,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
13921398

13931399
// Print floating point attributes.
13941400
if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1395-
if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
1396-
return emitError(loc,
1397-
"expected floating point attribute to be f32 or f64");
1401+
if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1402+
fAttr.getType())) {
1403+
return emitError(
1404+
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
13981405
}
13991406
printFloat(fAttr.getValue());
14001407
return success();
14011408
}
14021409
if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1403-
if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
1404-
return emitError(loc,
1405-
"expected floating point attribute to be f32 or f64");
1410+
if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1411+
dense.getElementType())) {
1412+
return emitError(
1413+
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
14061414
}
14071415
os << '{';
14081416
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
@@ -1819,6 +1827,14 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
18191827
}
18201828
if (auto fType = dyn_cast<FloatType>(type)) {
18211829
switch (fType.getWidth()) {
1830+
case 16: {
1831+
if (llvm::isa<Float16Type>(type))
1832+
return (os << "_Float16"), success();
1833+
else if (llvm::isa<BFloat16Type>(type))
1834+
return (os << "__bf16"), success();
1835+
else
1836+
return emitError(loc, "cannot emit float type ") << type;
1837+
}
18221838
case 32:
18231839
return (os << "float"), success();
18241840
case 64:

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 10 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -31,36 +31,35 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
3131
}
3232

3333
// -----
34-
35-
func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
34+
func.func @arith_cast_f80(%arg0: f80) -> i32 {
3635
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
37-
%t = arith.fptosi %arg0 : bf16 to i32
36+
%t = arith.fptosi %arg0 : f80 to i32
3837
return %t: i32
3938
}
4039

4140
// -----
4241

43-
func.func @arith_cast_f16(%arg0: f16) -> i32 {
42+
func.func @arith_cast_f128(%arg0: f128) -> i32 {
4443
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
45-
%t = arith.fptosi %arg0 : f16 to i32
44+
%t = arith.fptosi %arg0 : f128 to i32
4645
return %t: i32
4746
}
4847

4948

5049
// -----
5150

52-
func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
51+
func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
5352
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
54-
%t = arith.sitofp %arg0 : i32 to bf16
55-
return %t: bf16
53+
%t = arith.sitofp %arg0 : i32 to f80
54+
return %t: f80
5655
}
5756

5857
// -----
5958

60-
func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
59+
func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
6160
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
62-
%t = arith.sitofp %arg0 : i32 to f16
63-
return %t: f16
61+
%t = arith.sitofp %arg0 : i32 to f128
62+
return %t: f128
6463
}
6564

6665
// -----
@@ -143,23 +142,6 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
143142
return %divui: vector<5xi32>
144143
}
145144

146-
// -----
147-
148-
func.func @arith_extf_to_bf16(%arg0: f8E4M3FN) {
149-
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
150-
%ext = arith.extf %arg0 : f8E4M3FN to bf16
151-
return
152-
}
153-
154-
// -----
155-
156-
func.func @arith_extf_to_f16(%arg0: f8E4M3FN) {
157-
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
158-
%ext = arith.extf %arg0 : f8E4M3FN to f16
159-
return
160-
}
161-
162-
163145
// -----
164146

165147
func.func @arith_extf_to_tf32(%arg0: f8E4M3FN) {
@@ -202,22 +184,6 @@ func.func @arith_truncf_to_tf32(%arg0: f64) {
202184

203185
// -----
204186

205-
func.func @arith_truncf_to_f16(%arg0: f64) {
206-
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
207-
%trunc = arith.truncf %arg0 : f64 to f16
208-
return
209-
}
210-
211-
// -----
212-
213-
func.func @arith_truncf_to_bf16(%arg0: f64) {
214-
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
215-
%trunc = arith.truncf %arg0 : f64 to bf16
216-
return
217-
}
218-
219-
// -----
220-
221187
func.func @arith_truncf_to_f8E4M3FN(%arg0: f64) {
222188
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
223189
%trunc = arith.truncf %arg0 : f64 to f8E4M3FN

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
4646

4747
// -----
4848

49-
func.func @unsupported_type_f16() {
49+
func.func @unsupported_type_f128() {
5050
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
51-
%0 = memref.alloca() : memref<4xf16>
51+
%0 = memref.alloca() : memref<4xf128>
5252
return
5353
}
5454

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,6 @@ func.func @illegal_f8E5M2FNUZ_type(%arg0: f8E5M2FNUZ, %arg1: f8E5M2FNUZ) {
160160

161161
// -----
162162

163-
func.func @illegal_f16_type(%arg0: f16, %arg1: f16) {
164-
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f16'}}
165-
%mul = "emitc.mul" (%arg0, %arg1) : (f16, f16) -> f16
166-
return
167-
}
168-
169-
// -----
170-
171-
func.func @illegal_bf16_type(%arg0: bf16, %arg1: bf16) {
172-
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'bf16'}}
173-
%mul = "emitc.mul" (%arg0, %arg1) : (bf16, bf16) -> bf16
174-
return
175-
}
176-
177-
// -----
178-
179163
func.func @illegal_f80_type(%arg0: f80, %arg1: f80) {
180164
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f80'}}
181165
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80

mlir/test/Target/Cpp/const.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ func.func @emitc_constant() {
1111
%c6 = "emitc.constant"(){value = 2 : index} : () -> index
1212
%c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32
1313
%f64 = "emitc.constant"(){value = 4.0 : f64} : () -> f64
14+
%f16 = "emitc.constant"(){value = 2.0 : f16} : () -> f16
15+
%bf16 = "emitc.constant"(){value = 4.0 : bf16} : () -> bf16
1416
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
1517
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
1618
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
@@ -26,6 +28,8 @@ func.func @emitc_constant() {
2628
// CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2;
2729
// CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = 2.000000000e+00f;
2830
// CPP-DEFAULT-NEXT: double [[F64:[^ ]*]] = 4.00000000000000000e+00;
31+
// CPP-DEFAULT-NEXT: _Float16 [[F16:[^ ]*]] = 2.00000e+00f16;
32+
// CPP-DEFAULT-NEXT: __bf16 [[BF16:[^ ]*]] = 4.0000e+00bf16;
2933
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
3034
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
3135
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
@@ -40,6 +44,8 @@ func.func @emitc_constant() {
4044
// CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]];
4145
// CPP-DECLTOP-NEXT: float [[V7:[^ ]*]];
4246
// CPP-DECLTOP-NEXT: double [[F64:[^ ]*]];
47+
// CPP-DECLTOP-NEXT: _Float16 [[F16:[^ ]*]];
48+
// CPP-DECLTOP-NEXT: __bf16 [[BF16:[^ ]*]];
4349
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
4450
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
4551
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
@@ -52,6 +58,8 @@ func.func @emitc_constant() {
5258
// CPP-DECLTOP-NEXT: [[V6]] = 2;
5359
// CPP-DECLTOP-NEXT: [[V7]] = 2.000000000e+00f;
5460
// CPP-DECLTOP-NEXT: [[F64]] = 4.00000000000000000e+00;
61+
// CPP-DECLTOP-NEXT: [[F16]] = 2.00000e+00f16;
62+
// CPP-DECLTOP-NEXT: [[BF16]] = 4.0000e+00bf16;
5563
// CPP-DECLTOP-NEXT: [[V8]] = {0};
5664
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
5765
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};

mlir/test/Target/Cpp/types.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ func.func @ptr_types() {
3232
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
3333
// CHECK-NEXT: f<int64_t*>();
3434
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
35+
// CHECK-NEXT: f<_Float16*>();
36+
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
37+
// CHECK-NEXT: f<__bf16*>();
38+
emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
3539
// CHECK-NEXT: f<float*>();
3640
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
3741
// CHECK-NEXT: f<double*>();

0 commit comments

Comments
 (0)