Skip to content

Commit 6422546

Browse files
[mlir][LLVM] Fix conversion of non-standard MLIR float types (#122634)
Certain non-standard float types were directly passed through in the LLVM type converter, resulting in invalid IR or failed assertions: ``` mlir-opt: mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp:638: FailureOr<Type> mlir::LLVMTypeConverter::convertVectorType(VectorType) const: Assertion `LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"' failed. ``` The LLVM type converter should not define invalid type conversion rules for such types. If there is no type conversion rule, conversion patterns will not apply to ops with such operand types.
1 parent 7532958 commit 6422546

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
294294
}
295295

296296
Type LLVMTypeConverter::convertFloatType(FloatType type) const {
297+
// Valid LLVM float types are used directly.
298+
if (LLVM::isCompatibleType(type))
299+
return type;
300+
301+
// F4, F6, F8 types are converted to integer types with the same bit width.
297302
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
298303
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
299304
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
300305
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
301306
type.isFloat8E8M0FNU())
302307
return IntegerType::get(&getContext(), type.getWidth());
303-
return type;
308+
309+
// Other floating-point types: A custom type conversion rule must be
310+
// specified by the user.
311+
return Type();
304312
}
305313

306314
// Convert a `ComplexType` to an LLVM type. The result is a complex number

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
3737
return %1 : vector<4xf32>
3838
}
3939

40+
// -----
41+
4042
// CHECK-LABEL: @ops
4143
func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
4244
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
@@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
8486
%20 = arith.shrsi %arg2, %arg3 : i32
8587
// CHECK: = llvm.lshr %arg2, %arg3 : i32
8688
%21 = arith.shrui %arg2, %arg3 : i32
89+
// CHECK: arith.constant 2.000000e+00 : tf32
90+
// There is no type conversion rule for tf32.
91+
%22 = arith.constant 2.0 : tf32
8792
return %0, %10 : f32, i32
8893
}
8994

95+
// -----
96+
9097
// Checking conversion of index types to integers using i1, assuming no target
9198
// system would have a 1-bit address space. Otherwise, we would have had to
9299
// make this test dependent on the pointer size on the target system.
@@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
99106
return
100107
}
101108

109+
// -----
110+
102111
// CHECK-LABEL: @vector_index_cast
103112
func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
104113
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
108117
return
109118
}
110119

120+
// -----
121+
111122
func.func @index_castui(%arg0: index, %arg1: i1) {
112123
// CHECK: = llvm.trunc %0 : i{{.*}} to i1
113124
%0 = arith.index_castui %arg0: index to i1
@@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
116127
return
117128
}
118129

130+
// -----
131+
119132
// CHECK-LABEL: @vector_index_castui
120133
func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
121134
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
125138
return
126139
}
127140

141+
// -----
142+
128143
// Checking conversion of signed integer types to floating point.
129144
// CHECK-LABEL: @sitofp
130145
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
@@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
139154
return
140155
}
141156

157+
// -----
158+
142159
// Checking conversion of integer vectors to floating point vector types.
143160
// CHECK-LABEL: @sitofp_vector
144161
func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
157174
return
158175
}
159176

177+
// -----
178+
160179
// Checking conversion of unsigned integer types to floating point.
161180
// CHECK-LABEL: @uitofp
162181
func.func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
171190
return
172191
}
173192

193+
// -----
194+
174195
// Checking conversion of integer types to floating point.
175196
// CHECK-LABEL: @fpext
176197
func.func @fpext(%arg0 : f16, %arg1 : f32) {
@@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
183204
return
184205
}
185206

207+
// -----
208+
186209
// Checking conversion of integer types to floating point.
187210
// CHECK-LABEL: @fpext
188211
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
@@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
195218
return
196219
}
197220

221+
// -----
222+
198223
// Checking conversion of floating point to integer types.
199224
// CHECK-LABEL: @fptosi
200225
func.func @fptosi(%arg0 : f32, %arg1 : f64) {
@@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
209234
return
210235
}
211236

237+
// -----
238+
212239
// Checking conversion of floating point vectors to integer vector types.
213240
// CHECK-LABEL: @fptosi_vector
214241
func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
227254
return
228255
}
229256

257+
// -----
258+
230259
// Checking conversion of floating point to integer types.
231260
// CHECK-LABEL: @fptoui
232261
func.func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
241270
return
242271
}
243272

273+
// -----
274+
244275
// Checking conversion of floating point vectors to integer vector types.
245276
// CHECK-LABEL: @fptoui_vector
246277
func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
259290
return
260291
}
261292

293+
// -----
294+
262295
// Checking conversion of integer vectors to floating point vector types.
263296
// CHECK-LABEL: @uitofp_vector
264297
func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
277310
return
278311
}
279312

313+
// -----
314+
280315
// Checking conversion of integer types to floating point.
281316
// CHECK-LABEL: @fptrunc
282317
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
289324
return
290325
}
291326

327+
// -----
328+
292329
// Checking conversion of integer types to floating point.
293330
// CHECK-LABEL: @fptrunc
294331
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
@@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
301338
return
302339
}
303340

341+
// -----
342+
304343
// CHECK-LABEL: experimental_constrained_fptrunc
305344
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
306345
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
@@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
316355
return
317356
}
318357

358+
// -----
359+
319360
// Check sign and zero extension and truncation of integers.
320361
// CHECK-LABEL: @integer_extension_and_truncation
321362
func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
328369
return
329370
}
330371

372+
// -----
373+
331374
// CHECK-LABEL: @integer_cast_0d_vector
332375
func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
333376
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
@@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
340383
return
341384
}
342385

386+
// -----
387+
343388
// CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
344389
func.func @fcmp(f32, f32) -> () {
345390
^bb0(%arg0: f32, %arg1: f32):

mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
555555
return %arg1 : index
556556
}
557557

558+
// There is no type conversion rule for tf32, so vector<1xtf32> and, therefore,
559+
// the func op cannot be converted.
560+
// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
561+
// CHECK: llvm.return
562+
func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
563+
return
564+
}
565+
558566
module attributes {transform.with_named_sequence} {
559567
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
560568
%func = transform.structured.match ops{["func.func"]} in %toplevel_module

0 commit comments

Comments
 (0)