Skip to content

Commit 2a5c301

Browse files
committed
[CIR][X86] Implement lowering for pmuldq / pmuludq builtins
This patch adds CIR codegen support for X86 pmuldq and pmuludq operations, covering the signed and unsigned variants across all supported vector widths. The builtins now lower to the expected CIR representation matching the semantics of the corresponding LLVM intrinsics.
1 parent 1f35b52 commit 2a5c301

File tree

5 files changed

+222
-2
lines changed

5 files changed

+222
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,40 @@ static mlir::Value emitX86FunnelShift(CIRGenBuilderTy &builder,
269269
mlir::ValueRange{op0, op1, amt});
270270
}
271271

272+
static mlir::Value emitX86Muldq(CIRGenBuilderTy &builder, mlir::Location loc,
273+
bool isSigned,
274+
SmallVectorImpl<mlir::Value> &ops,
275+
unsigned opTypePrimitiveSizeInBits) {
276+
mlir::Type ty = cir::VectorType::get(builder.getSInt64Ty(),
277+
opTypePrimitiveSizeInBits / 64);
278+
mlir::Value lhs = builder.createBitcast(loc, ops[0], ty);
279+
mlir::Value rhs = builder.createBitcast(loc, ops[1], ty);
280+
if (isSigned) {
281+
cir::ConstantOp shiftAmt =
282+
builder.getConstant(loc, cir::IntAttr::get(builder.getSInt64Ty(), 32));
283+
cir::VecSplatOp shiftSplatVecOp =
284+
cir::VecSplatOp::create(builder, loc, ty, shiftAmt.getResult());
285+
mlir::Value shiftSplatValue = shiftSplatVecOp.getResult();
286+
// In CIR, right-shift operations are automatically lowered to either an
287+
// arithmetic or logical shift depending on the operand type. The purpose
288+
// of the shifts here is to propagate the sign bit of the 32-bit input
289+
// into the upper bits of each vector lane.
290+
lhs = builder.createShift(loc, lhs, shiftSplatValue, true);
291+
lhs = builder.createShift(loc, lhs, shiftSplatValue, false);
292+
rhs = builder.createShift(loc, rhs, shiftSplatValue, true);
293+
rhs = builder.createShift(loc, rhs, shiftSplatValue, false);
294+
} else {
295+
cir::ConstantOp maskScalar = builder.getConstant(
296+
loc, cir::IntAttr::get(builder.getSInt64Ty(), 0xffffffff));
297+
cir::VecSplatOp mask =
298+
cir::VecSplatOp::create(builder, loc, ty, maskScalar.getResult());
299+
// Clear the upper bits
300+
lhs = builder.createAnd(loc, lhs, mask);
301+
rhs = builder.createAnd(loc, rhs, mask);
302+
}
303+
return builder.createMul(loc, lhs, rhs);
304+
}
305+
272306
mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
273307
const CallExpr *expr) {
274308
if (builtinID == Builtin::BI__builtin_cpu_is) {
@@ -1125,12 +1159,26 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
11251159
case X86::BI__builtin_ia32_sqrtph512:
11261160
case X86::BI__builtin_ia32_sqrtps512:
11271161
case X86::BI__builtin_ia32_sqrtpd512:
1162+
cgm.errorNYI(expr->getSourceRange(),
1163+
std::string("unimplemented X86 builtin call: ") +
1164+
getContext().BuiltinInfo.getName(builtinID));
1165+
return {};
11281166
case X86::BI__builtin_ia32_pmuludq128:
11291167
case X86::BI__builtin_ia32_pmuludq256:
1130-
case X86::BI__builtin_ia32_pmuludq512:
1168+
case X86::BI__builtin_ia32_pmuludq512: {
1169+
unsigned opTypePrimitiveSizeInBits =
1170+
cgm.getDataLayout().getTypeSizeInBits(ops[0].getType());
1171+
return emitX86Muldq(builder, getLoc(expr->getExprLoc()), /*isSigned*/ false,
1172+
ops, opTypePrimitiveSizeInBits);
1173+
}
11311174
case X86::BI__builtin_ia32_pmuldq128:
11321175
case X86::BI__builtin_ia32_pmuldq256:
1133-
case X86::BI__builtin_ia32_pmuldq512:
1176+
case X86::BI__builtin_ia32_pmuldq512: {
1177+
unsigned opTypePrimitiveSizeInBits =
1178+
cgm.getDataLayout().getTypeSizeInBits(ops[0].getType());
1179+
return emitX86Muldq(builder, getLoc(expr->getExprLoc()), /*isSigned*/ true,
1180+
ops, opTypePrimitiveSizeInBits);
1181+
}
11341182
case X86::BI__builtin_ia32_pternlogd512_mask:
11351183
case X86::BI__builtin_ia32_pternlogq512_mask:
11361184
case X86::BI__builtin_ia32_pternlogd128_mask:

clang/test/CIR/CodeGenBuiltins/X86/avx2-builtins.c

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,55 @@ __m256i test_mm256_shufflehi_epi16(__m256i a) {
5151
// OGCG: shufflevector <16 x i16> %{{.*}}, <16 x i16> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 7, i32 6, i32 6, i32 5, i32 8, i32 9, i32 10, i32 11, i32 15, i32 14, i32 14, i32 13>
5252
return _mm256_shufflehi_epi16(a, 107);
5353
}
54+
55+
__m256i test_mm256_mul_epu32(__m256i a, __m256i b) {
56+
// CIR-LABEL: _mm256_mul_epu32
57+
// CIR: [[BC_A:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<4 x !s64i>
58+
// CIR: [[BC_B:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<4 x !s64i>
59+
// CIR: [[MASK_SCALAR:%.*]] = cir.const #cir.int<4294967295> : !s64i
60+
// CIR: [[MASK_VEC:%.*]] = cir.vec.splat [[MASK_SCALAR]] : !s64i, !cir.vector<4 x !s64i>
61+
// CIR: [[AND_A:%.*]] = cir.binop(and, [[BC_A]], [[MASK_VEC]])
62+
// CIR: [[AND_B:%.*]] = cir.binop(and, [[BC_B]], [[MASK_VEC]])
63+
// CIR: [[MUL:%.*]] = cir.binop(mul, [[AND_A]], [[AND_B]])
64+
65+
// LLVM-LABEL: _mm256_mul_epu32
66+
// LLVM: and <4 x i64> %{{.*}}, splat (i64 4294967295)
67+
// LLVM: and <4 x i64> %{{.*}}, splat (i64 4294967295)
68+
// LLVM: mul <4 x i64> %{{.*}}, %{{.*}}
69+
70+
// OGCG-LABEL: _mm256_mul_epu32
71+
// OGCG: and <4 x i64> %{{.*}}, splat (i64 4294967295)
72+
// OGCG: and <4 x i64> %{{.*}}, splat (i64 4294967295)
73+
// OGCG: mul <4 x i64> %{{.*}}, %{{.*}}
74+
75+
return _mm256_mul_epu32(a, b);
76+
}
77+
78+
__m256i test_mm256_mul_epi32(__m256i a, __m256i b) {
79+
// CIR-LABEL: _mm256_mul_epi32
80+
// CIR: [[A64:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<4 x !s64i>
81+
// CIR: [[B64:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<4 x !s64i>
82+
// CIR: [[SC:%.*]] = cir.const #cir.int<32> : !s64i
83+
// CIR: [[SV:%.*]] = cir.vec.splat [[SC]] : !s64i, !cir.vector<4 x !s64i>
84+
// CIR: [[SHL_A:%.*]] = cir.shift(left, [[A64]] : !cir.vector<4 x !s64i>, [[SV]] : !cir.vector<4 x !s64i>)
85+
// CIR: [[ASHR_A:%.*]] = cir.shift(right, [[SHL_A]] : !cir.vector<4 x !s64i>, [[SV]] : !cir.vector<4 x !s64i>)
86+
// CIR: [[SHL_B:%.*]] = cir.shift(left, [[B64]] : !cir.vector<4 x !s64i>, [[SV]] : !cir.vector<4 x !s64i>)
87+
// CIR: [[ASHR_B:%.*]] = cir.shift(right, [[SHL_B]] : !cir.vector<4 x !s64i>, [[SV]] : !cir.vector<4 x !s64i>)
88+
// CIR: [[MUL:%.*]] = cir.binop(mul, [[ASHR_A]], [[ASHR_B]])
89+
90+
// LLVM-LABEL: _mm256_mul_epi32
91+
// LLVM: shl <4 x i64> %{{.*}}, splat (i64 32)
92+
// LLVM: ashr <4 x i64> %{{.*}}, splat (i64 32)
93+
// LLVM: shl <4 x i64> %{{.*}}, splat (i64 32)
94+
// LLVM: ashr <4 x i64> %{{.*}}, splat (i64 32)
95+
// LLVM: mul <4 x i64> %{{.*}}, %{{.*}}
96+
97+
// OGCG-LABEL: _mm256_mul_epi32
98+
// OGCG: shl <4 x i64> %{{.*}}, splat (i64 32)
99+
// OGCG: ashr <4 x i64> %{{.*}}, splat (i64 32)
100+
// OGCG: shl <4 x i64> %{{.*}}, splat (i64 32)
101+
// OGCG: ashr <4 x i64> %{{.*}}, splat (i64 32)
102+
// OGCG: mul <4 x i64> %{{.*}}, %{{.*}}
103+
104+
return _mm256_mul_epi32(a, b);
105+
}

clang/test/CIR/CodeGenBuiltins/X86/avx512f-builtins.c

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,3 +527,55 @@ __m512i test_mm512_ror_epi64(__m512i __A) {
527527
// OGCG: call <8 x i64> @llvm.fshr.v8i64(<8 x i64> %[[VAR]], <8 x i64> %[[VAR]], <8 x i64> splat (i64 5))
528528
return _mm512_ror_epi64(__A, 5);
529529
}
530+
531+
__m512i test_mm512_mul_epi32(__m512i __A, __m512i __B) {
532+
// CIR-LABEL: _mm512_mul_epi32
533+
// CIR: [[A64:%.*]] = cir.cast bitcast %{{.*}} : !cir.vector<16 x !s32i> -> !cir.vector<8 x !s64i>
534+
// CIR: [[B64:%.*]] = cir.cast bitcast %{{.*}} : !cir.vector<16 x !s32i> -> !cir.vector<8 x !s64i>
535+
// CIR: [[SC:%.*]] = cir.const #cir.int<32> : !s64i
536+
// CIR: [[SV:%.*]] = cir.vec.splat [[SC]] : !s64i, !cir.vector<8 x !s64i>
537+
// CIR: [[SHL_A:%.*]] = cir.shift(left, [[A64]] : !cir.vector<8 x !s64i>, [[SV]] : !cir.vector<8 x !s64i>)
538+
// CIR: [[ASHR_A:%.*]] = cir.shift(right, [[SHL_A]] : !cir.vector<8 x !s64i>, [[SV]] : !cir.vector<8 x !s64i>)
539+
// CIR: [[SHL_B:%.*]] = cir.shift(left, [[B64]] : !cir.vector<8 x !s64i>, [[SV]] : !cir.vector<8 x !s64i>)
540+
// CIR: [[ASHR_B:%.*]] = cir.shift(right, [[SHL_B]] : !cir.vector<8 x !s64i>, [[SV]] : !cir.vector<8 x !s64i>)
541+
// CIR: [[MUL:%.*]] = cir.binop(mul, [[ASHR_A]], [[ASHR_B]])
542+
543+
// LLVM-LABEL: _mm512_mul_epi32
544+
// LLVM: shl <8 x i64> %{{.*}}, splat (i64 32)
545+
// LLVM: ashr <8 x i64> %{{.*}}, splat (i64 32)
546+
// LLVM: shl <8 x i64> %{{.*}}, splat (i64 32)
547+
// LLVM: ashr <8 x i64> %{{.*}}, splat (i64 32)
548+
// LLVM: mul <8 x i64> %{{.*}}, %{{.*}}
549+
550+
// OGCG-LABEL: _mm512_mul_epi32
551+
// OGCG: shl <8 x i64> %{{.*}}, splat (i64 32)
552+
// OGCG: ashr <8 x i64> %{{.*}}, splat (i64 32)
553+
// OGCG: shl <8 x i64> %{{.*}}, splat (i64 32)
554+
// OGCG: ashr <8 x i64> %{{.*}}, splat (i64 32)
555+
// OGCG: mul <8 x i64> %{{.*}}, %{{.*}}
556+
557+
return _mm512_mul_epi32(__A, __B);
558+
}
559+
560+
__m512i test_mm512_mul_epu32(__m512i __A, __m512i __B) {
561+
// CIR-LABEL: _mm512_mul_epu32
562+
// CIR: [[BC_A:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<8 x !s64i>
563+
// CIR: [[BC_B:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<8 x !s64i>
564+
// CIR: [[MASK_SCALAR:%.*]] = cir.const #cir.int<4294967295> : !s64i
565+
// CIR: [[MASK_VEC:%.*]] = cir.vec.splat [[MASK_SCALAR]] : !s64i, !cir.vector<8 x !s64i>
566+
// CIR: [[AND_A:%.*]] = cir.binop(and, [[BC_A]], [[MASK_VEC]])
567+
// CIR: [[AND_B:%.*]] = cir.binop(and, [[BC_B]], [[MASK_VEC]])
568+
// CIR: [[MUL:%.*]] = cir.binop(mul, [[AND_A]], [[AND_B]])
569+
570+
// LLVM-LABEL: _mm512_mul_epu32
571+
// LLVM: and <8 x i64> %{{.*}}, splat (i64 4294967295)
572+
// LLVM: and <8 x i64> %{{.*}}, splat (i64 4294967295)
573+
// LLVM: mul <8 x i64> %{{.*}}, %{{.*}}
574+
575+
// OGCG-LABEL: _mm512_mul_epu32
576+
// OGCG: and <8 x i64> %{{.*}}, splat (i64 4294967295)
577+
// OGCG: and <8 x i64> %{{.*}}, splat (i64 4294967295)
578+
// OGCG: mul <8 x i64> %{{.*}}, %{{.*}}
579+
580+
return _mm512_mul_epu32(__A, __B);
581+
}

clang/test/CIR/CodeGenBuiltins/X86/sse2-builtins.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,26 @@ __m128i test_mm_shuffle_epi32(__m128i A) {
159159
// OGCG: shufflevector <4 x i32> %{{.*}}, <4 x i32> poison, <4 x i32> <i32 2, i32 3, i32 0, i32 1>
160160
return _mm_shuffle_epi32(A, 0x4E);
161161
}
162+
163+
__m128i test_mm_mul_epu32(__m128i A, __m128i B) {
164+
// CIR-LABEL: _mm_mul_epu32
165+
// CIR: [[BC_A:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<2 x !s64i>
166+
// CIR: [[BC_B:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<2 x !s64i>
167+
// CIR: [[MASK_SCALAR:%.*]] = cir.const #cir.int<4294967295> : !s64i
168+
// CIR: [[MASK_VEC:%.*]] = cir.vec.splat [[MASK_SCALAR]] : !s64i, !cir.vector<2 x !s64i>
169+
// CIR: [[AND_A:%.*]] = cir.binop(and, [[BC_A]], [[MASK_VEC]])
170+
// CIR: [[AND_B:%.*]] = cir.binop(and, [[BC_B]], [[MASK_VEC]])
171+
// CIR: [[MUL:%.*]] = cir.binop(mul, [[AND_A]], [[AND_B]])
172+
173+
// LLVM-LABEL: _mm_mul_epu32
174+
// LLVM: and <2 x i64> %{{.*}}, splat (i64 4294967295)
175+
// LLVM: and <2 x i64> %{{.*}}, splat (i64 4294967295)
176+
// LLVM: mul <2 x i64> %{{.*}}, %{{.*}}
177+
178+
// OGCG-LABEL: _mm_mul_epu32
179+
// OGCG: and <2 x i64> %{{.*}}, splat (i64 4294967295)
180+
// OGCG: and <2 x i64> %{{.*}}, splat (i64 4294967295)
181+
// OGCG: mul <2 x i64> %{{.*}}, %{{.*}}
182+
183+
return _mm_mul_epu32(A, B);
184+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: %clang_cc1 -x c -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +sse4.1 -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion
2+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
3+
// RUN: %clang_cc1 -x c -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +sse4.1 -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
4+
// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s
5+
6+
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +sse4.1 -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion
7+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
8+
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +sse4.1 -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
9+
// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s
10+
11+
// RUN: %clang_cc1 -x c -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-apple-darwin -target-feature +sse4.1 -emit-llvm -o - -Wall -Werror -Wsign-conversion | FileCheck %s --check-prefixes=OGCG
12+
// RUN: %clang_cc1 -x c -flax-vector-conversions=none -fms-extensions -fms-compatibility -ffreestanding %s -triple=x86_64-windows-msvc -target-feature +sse4.1 -emit-llvm -o - -Wall -Werror -Wsign-conversion | FileCheck %s --check-prefixes=OGCG
13+
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-apple-darwin -target-feature +sse4.1 -emit-llvm -o - -Wall -Werror -Wsign-conversion | FileCheck %s --check-prefixes=OGCG
14+
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -fms-extensions -fms-compatibility -ffreestanding %s -triple=x86_64-windows-msvc -target-feature +sse4.1 -emit-llvm -o - -Wall -Werror -Wsign-conversion | FileCheck %s --check-prefixes=OGCG
15+
16+
#include <immintrin.h>
17+
18+
__m128i test_mm_mul_epi32(__m128i x, __m128i y) {
19+
// CIR-LABEL: _mm_mul_epi32
20+
// CIR: [[A64:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<2 x !s64i>
21+
// CIR: [[B64:%.*]] = cir.cast bitcast %{{.*}} : {{.*}} -> !cir.vector<2 x !s64i>
22+
// CIR: [[SC:%.*]] = cir.const #cir.int<32> : !s64i
23+
// CIR: [[SV:%.*]] = cir.vec.splat [[SC]] : !s64i, !cir.vector<2 x !s64i>
24+
// CIR: [[SHL_A:%.*]] = cir.shift(left, [[A64]] : !cir.vector<2 x !s64i>, [[SV]] : !cir.vector<2 x !s64i>)
25+
// CIR: [[ASHR_A:%.*]] = cir.shift(right, [[SHL_A]] : !cir.vector<2 x !s64i>, [[SV]] : !cir.vector<2 x !s64i>)
26+
// CIR: [[SHL_B:%.*]] = cir.shift(left, [[B64]] : !cir.vector<2 x !s64i>, [[SV]] : !cir.vector<2 x !s64i>)
27+
// CIR: [[ASHR_B:%.*]] = cir.shift(right, [[SHL_B]] : !cir.vector<2 x !s64i>, [[SV]] : !cir.vector<2 x !s64i>)
28+
// CIR: [[MUL:%.*]] = cir.binop(mul, [[ASHR_A]], [[ASHR_B]])
29+
30+
// LLVM-LABEL: _mm_mul_epi32
31+
// LLVM: shl <2 x i64> %{{.*}}, splat (i64 32)
32+
// LLVM: ashr <2 x i64> %{{.*}}, splat (i64 32)
33+
// LLVM: shl <2 x i64> %{{.*}}, splat (i64 32)
34+
// LLVM: ashr <2 x i64> %{{.*}}, splat (i64 32)
35+
// LLVM: mul <2 x i64> %{{.*}}, %{{.*}}
36+
37+
// OGCG-LABEL: _mm_mul_epi32
38+
// OGCG: shl <2 x i64> %{{.*}}, splat (i64 32)
39+
// OGCG: ashr <2 x i64> %{{.*}}, splat (i64 32)
40+
// OGCG: shl <2 x i64> %{{.*}}, splat (i64 32)
41+
// OGCG: ashr <2 x i64> %{{.*}}, splat (i64 32)
42+
// OGCG: mul <2 x i64> %{{.*}}, %{{.*}}
43+
44+
return _mm_mul_epi32(x, y);
45+
}

0 commit comments

Comments
 (0)