Skip to content

Commit e556764

Browse files
committed
cast to unsigned properly
1 parent 4c41ac6 commit e556764

File tree

2 files changed

+35
-21
lines changed

2 files changed

+35
-21
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,27 +95,35 @@ static mlir::Value getMaskVecValue(CIRGenFunction &cgf, const CallExpr *expr,
9595
static mlir::Value emitX86FunnelShift(CIRGenFunction &cgf, const CallExpr *e,
9696
mlir::Value &op0, mlir::Value &op1,
9797
mlir::Value &amt, bool isRight) {
98-
auto ty = op0.getType();
98+
auto &builder = cgf.getBuilder();
99+
auto op0Ty = op0.getType();
99100

100101
// Amount may be scalar immediate, in which case create a splat vector.
101102
// Funnel shift amounts are treated as modulo and types are all power-of-2
102103
// so we only care about the lowest log2 bits anyway.
103-
if (amt.getType() != ty) {
104-
auto vecTy = mlir::cast<cir::VectorType>(ty);
105-
104+
if (amt.getType() != op0Ty) {
105+
auto vecTy = mlir::cast<cir::VectorType>(op0Ty);
106106
auto numElems = vecTy.getSize();
107-
auto vecElemType = mlir::cast<cir::IntType>(vecTy.getElementType());
108-
auto signlessType =
109-
cir::IntType::get(&cgf.getMLIRContext(), vecElemType.getWidth(), false);
110-
amt = cgf.getBuilder().createIntCast(amt, signlessType);
111107

112-
amt = cir::VecSplatOp::create(cgf.getBuilder(), cgf.getLoc(e->getExprLoc()),
113-
cir::VectorType::get(signlessType, numElems),
114-
amt);
108+
auto amtTy = mlir::cast<cir::IntType>(amt.getType());
109+
auto vecElemTy = mlir::cast<cir::IntType>(vecTy.getElementType());
110+
111+
// Cast to the same-width unsigned type if not already unsigned.
112+
if (amtTy.isSigned()) {
113+
auto unsignedAmtTy = builder.getUIntNTy(amtTy.getWidth());
114+
amt = builder.createIntCast(amt,
115+
builder.getUIntNTy(unsignedAmtTy.getWidth()));
116+
}
117+
// Cast the unsigned `amt` to the unsigned type of the operand element width.
118+
auto unsingedVecElemType = builder.getUIntNTy(vecElemTy.getWidth());
119+
amt = builder.createIntCast(amt, unsingedVecElemType);
120+
amt = cir::VecSplatOp::create(
121+
builder, cgf.getLoc(e->getExprLoc()),
122+
cir::VectorType::get(unsingedVecElemType, numElems), amt);
115123
}
116124

117125
const std::string intrinsicName = isRight ? "fshr" : "fshl";
118-
return emitIntrinsicCallOp(cgf, e, intrinsicName, ty,
126+
return emitIntrinsicCallOp(cgf, e, intrinsicName, op0Ty,
119127
mlir::ValueRange{op0, op1, amt});
120128
}
121129

clang/test/CIR/CodeGen/X86/xop-builtins.c

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ __m128i test_mm_roti_epi8(__m128i a) {
4343

4444
__m128i test_mm_roti_epi16(__m128i a) {
4545
// CIR-LABEL: test_mm_roti_epi16
46-
// CIR: {{%.*}} = cir.cast integral {{%.*}} : !{{[us]}}8i -> !u16i
47-
// CIR: {{%.*}} = cir.vec.splat {{%.*}} : !{{[us]}}16i, !cir.vector<8 x !{{[us]}}16i>
48-
// CIR: {{%.*}} = cir.call_llvm_intrinsic "fshl" {{.*}} : (!cir.vector<8 x !{{[su]}}16i>, !cir.vector<8 x !{{[su]}}16i>, !cir.vector<8 x !{{[su]}}16i>) -> !cir.vector<8 x !{{[su]}}16i>
46+
// CIR: {{%.*}} = cir.cast integral {{%.*}} : !u8i -> !u16i
47+
// CIR: {{%.*}} = cir.vec.splat {{%.*}} : !{{[us]}}16i, !cir.vector<8 x !u16i>
48+
// CIR: {{%.*}} = cir.call_llvm_intrinsic "fshl" {{.*}} : (!cir.vector<8 x !{{[su]}}16i>, !cir.vector<8 x !{{[su]}}16i>, !cir.vector<8 x !u16i>) -> !cir.vector<8 x !{{[su]}}16i>
4949
// LLVM-LABEL: test_mm_roti_epi16
5050
// LLVM: %[[CASTED_VAR:.*]] = bitcast <2 x i64> {{%.*}} to <8 x i16>
5151
// LLVM: {{%.*}} = call <8 x i16> @llvm.fshl.v8i16(<8 x i16> %[[CASTED_VAR]], <8 x i16> %[[CASTED_VAR]], <8 x i16> splat (i16 50))
@@ -58,17 +58,23 @@ __m128i test_mm_roti_epi16(__m128i a) {
5858
//NOTE: This only works as I expect for CIR but not for LLVMIR
5959
__m128i test_mm_roti_epi32(__m128i a) {
6060
// CIR-LABEL: test_mm_roti_epi32
61-
// CIR: {{%.*}} = cir.cast integral {{%.*}} : !{{[us]}}8i -> !u32i
62-
// CIR: {{%.*}} = cir.vec.splat {{%.*}} : !{{[us]}}32i, !cir.vector<4 x !{{[us]}}32i>
63-
// CIR: {{%.*}} = cir.call_llvm_intrinsic "fshl" {{.*}} : (!cir.vector<4 x !{{[su]}}32i>, !cir.vector<4 x !{{[su]}}32i>, !cir.vector<4 x !{{[su]}}32i>) -> !cir.vector<4 x !{{[su]}}32i>
61+
// CIR: {{%.*}} = cir.cast integral {{%.*}} : !u8i -> !u32i
62+
// CIR: {{%.*}} = cir.vec.splat {{%.*}} : !{{[us]}}32i, !cir.vector<4 x !u32i>
63+
// CIR: {{%.*}} = cir.call_llvm_intrinsic "fshl" {{.*}} : (!cir.vector<4 x !{{[su]}}32i>, !cir.vector<4 x !{{[su]}}32i>, !cir.vector<4 x !u32i>) -> !cir.vector<4 x !{{[su]}}32i>
64+
// LLVM-LABEL: test_mm_roti_epi32
65+
// LLVM: %[[CASTED_VAR:.*]] = bitcast <2 x i64> {{%.*}} to <4 x i32>
66+
// LLVM: {{%.*}} = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %[[CASTED_VAR]], <4 x i32> %[[CASTED_VAR]], <4 x i32> splat (i32 226))
67+
// OGCG-LABEL: test_mm_roti_epi32
68+
// OGCG: %[[CASTED_VAR:.*]] = bitcast <2 x i64> {{%.*}} to <4 x i32>
69+
// OGCG: {{%.*}} = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %[[CASTED_VAR]], <4 x i32> %[[CASTED_VAR]], <4 x i32> splat (i32 226))
6470
return _mm_roti_epi32(a, -30);
6571
}
6672

6773
__m128i test_mm_roti_epi64(__m128i a) {
6874
// CIR-LABEL: test_mm_roti_epi64
69-
// CIR: {{%.*}} = cir.cast integral {{%.*}} : !{{[us]}}8i -> !u64i
70-
// CIR: {{%.*}} = cir.vec.splat {{%.*}} : !{{.}}64i, !cir.vector<2 x !{{[us]}}64i>
71-
// CIR: {{%.*}} = cir.call_llvm_intrinsic "fshl" {{.*}} : (!cir.vector<2 x !{{[su]}}64i>, !cir.vector<2 x !{{[su]}}64i>, !cir.vector<2 x !u64i>) -> !cir.vector<2 x !{{[su]}}64i>
75+
// CIR: {{%.*}} = cir.cast integral {{%.*}} : !u8i -> !u64i
76+
// CIR: {{%.*}} = cir.vec.splat {{%.*}} : !u64i, !cir.vector<2 x !u64i>
77+
// CIR: {{%.*}} = cir.call_llvm_intrinsic "fshl" {{.*}} : (!cir.vector<2 x !{{[su]}}64i>, !cir.vector<2 x !{{[su]}}64i>, !cir.vector<2 x !u64i>) -> !cir.vector<2 x !s64i>
7278
// LLVM-LABEL: test_mm_roti_epi64
7379
// LLVM: %[[VAR:.*]] = load <2 x i64>, ptr {{%.*}}, align 16
7480
// LLVM: {{%.*}} = call <2 x i64> @llvm.fshl.v2i64(<2 x i64> %[[VAR]], <2 x i64> %[[VAR]], <2 x i64> splat (i64 100))

0 commit comments

Comments
 (0)