Skip to content

Commit c8bd004

Browse files
tommymcmlanza
authored andcommitted
[CIR] Improved cir::CastOp verifier to allow bitcasts between types of the same size (llvm#1728)
The `cir::CastOp::verify` method was overly conservative, and would fail on any `bitcast` from vector to scalar or scalar to vector. Change List: - Extends the `cir::CastOp::verify` method to check if the source and result types are the same size using the `mlir::DataLayout` of the current scope, and succeeds if the sizes match. - Extends the CodeGen vectype tests with vector to scalar, scalar to vector and vector to vector conversions. - Extends the IR invalid tests with vector to scalar and scalar to vector conversions with different source and result sizes.
1 parent 07b7c6a commit c8bd004

File tree

3 files changed

+57
-5
lines changed

3 files changed

+57
-5
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/IR/Location.h"
3636
#include "mlir/IR/OpDefinition.h"
3737
#include "mlir/IR/OpImplementation.h"
38+
#include "mlir/IR/Operation.h"
3839
#include "mlir/IR/StorageUniquerSupport.h"
3940
#include "mlir/IR/TypeUtilities.h"
4041
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@@ -570,6 +571,17 @@ LogicalResult cir::CastOp::verify() {
570571
mlir::isa<cir::MethodType>(resType))
571572
return success();
572573

574+
// Handle scalar to vector and vector to scalar conversions.
575+
if (mlir::isa<cir::VectorType>(getSrc().getType()) !=
576+
mlir::isa<cir::VectorType>(getType())) {
577+
// The source and result must be the same size.
578+
mlir::DataLayout dataLayout(
579+
getOperation()->getParentOfType<mlir::DataLayoutOpInterface>());
580+
if (dataLayout.getTypeSize(getSrc().getType()) ==
581+
dataLayout.getTypeSize(getType()))
582+
return success();
583+
}
584+
573585
// This is the only cast kind where we don't want vector types to decay
574586
// into the element type.
575587
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||

clang/test/CIR/CodeGen/vectype.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,31 @@ void vector_int_test(int x, unsigned short usx) {
120120

121121
// Shifts
122122
vi4 w = a << b;
123-
// CHECK: %{{[0-9]+}} = cir.shift(left, {{%.*}} : !cir.vector<!s32i x 4>,
123+
// CHECK: %{{[0-9]+}} = cir.shift(left, {{%.*}} : !cir.vector<!s32i x 4>,
124124
// CHECK-SAME: {{%.*}} : !cir.vector<!s32i x 4>) -> !cir.vector<!s32i x 4>
125125
vi4 y = a >> b;
126-
// CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector<!s32i x 4>,
126+
// CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector<!s32i x 4>,
127127
// CHECK-SAME: {{%.*}} : !cir.vector<!s32i x 4>) -> !cir.vector<!s32i x 4>
128128

129-
vus2 z = { usx, usx };
129+
vus2 z = { usx, usx };
130130
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !u16i, !u16i) : !cir.vector<!u16i x 2>
131131
vus2 zamt = { 3, 4 };
132132
// CHECK: %{{[0-9]+}} = cir.const #cir.const_vector<[#cir.int<3> : !u16i, #cir.int<4> : !u16i]> : !cir.vector<!u16i x 2>
133133
vus2 zzz = z >> zamt;
134-
// CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector<!u16i x 2>,
135-
// CHECK-SAME: {{%.*}} : !cir.vector<!u16i x 2>) -> !cir.vector<!u16i x 2>
134+
// CHECK: %{{[0-9]+}} = cir.shift(right, {{%.*}} : !cir.vector<!u16i x 2>,
135+
// CHECK-SAME: {{%.*}} : !cir.vector<!u16i x 2>) -> !cir.vector<!u16i x 2>
136+
137+
// Vector to scalar conversion
138+
unsigned int zi = (unsigned int)z;
139+
// CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !cir.vector<!u16i x 2>), !u32i
140+
141+
// Scalar to vector conversion
142+
vus2 zz = (vus2)zi;
143+
// CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !u32i), !cir.vector<!u16i x 2>
144+
145+
// Vector to vector conversion
146+
vll2 aaa = (vll2)a;
147+
// CHECK: %{{[0-9]+}} = cir.cast(bitcast, {{%.*}} : !cir.vector<!s32i x 4>), !cir.vector<!s64i x 2>
136148
}
137149

138150
void vector_double_test(int x, double y) {

clang/test/CIR/IR/invalid.cir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,7 +1410,35 @@ module {
14101410
// expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}}
14111411
%1 = cir.cast(bitcast, %0 : !cir.ptr<!s32i>), !cir.ptr<!s32i, addrspace(offload_local)>
14121412
}
1413+
}
1414+
1415+
// -----
1416+
1417+
!s16i = !cir.int<s, 16>
1418+
!s64i = !cir.int<s, 64>
14131419

1420+
module {
1421+
cir.func @test_bitcast_vec2scalar_diff_size() {
1422+
%0 = cir.const #cir.int<1> : !s16i
1423+
%1 = cir.vec.create(%0, %0 : !s16i, !s16i) : !cir.vector<!s16i x 2>
1424+
// expected-error@+1 {{'cir.cast' op requires !cir.ptr or !cir.vector type for source and result}}
1425+
%2 = cir.cast(bitcast, %1 : !cir.vector<!s16i x 2>), !s64i
1426+
cir.return
1427+
}
1428+
}
1429+
1430+
// -----
1431+
1432+
!s32i = !cir.int<s, 32>
1433+
!s64i = !cir.int<s, 64>
1434+
1435+
module {
1436+
cir.func @test_bitcast_scalar2vec_diff_size() {
1437+
%0 = cir.const #cir.int<1> : !s64i
1438+
// expected-error@+1 {{'cir.cast' op requires !cir.ptr or !cir.vector type for source and result}}
1439+
%1 = cir.cast(bitcast, %0 : !s64i), !cir.vector<!s32i x 4>
1440+
cir.return
1441+
}
14141442
}
14151443

14161444
// -----

0 commit comments

Comments
 (0)