Skip to content

Commit 491504e

Browse files
committed
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - allow AVX/AVX512 subvector insertion intrinsics to be used in constexpr #157709
1 parent bb263f5 commit 491504e

File tree

4 files changed

+219
-8
lines changed

4 files changed

+219
-8
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,6 @@ let Features = "avx", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWid
502502
let Features = "avx", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
503503
def vpermilpd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Constant int)">;
504504
def vpermilps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Constant int)">;
505-
def vinsertf128_pd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Vector<2, double>, _Constant int)">;
506-
def vinsertf128_ps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<4, float>, _Constant int)">;
507-
def vinsertf128_si256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>, _Constant int)">;
508505
def sqrtpd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>)">;
509506
def sqrtps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>)">;
510507
def rsqrtps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>)">;
@@ -513,6 +510,12 @@ let Features = "avx", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in
513510
def roundps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Constant int)">;
514511
}
515512

513+
let Features = "avx", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
514+
def vinsertf128_pd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Vector<2, double>, _Constant int)">;
515+
def vinsertf128_ps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<4, float>, _Constant int)">;
516+
def vinsertf128_si256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>, _Constant int)">;
517+
}
518+
516519
let Features = "avx", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
517520
def vtestzpd : X86Builtin<"int(_Vector<2, double>, _Vector<2, double>)">;
518521
def vtestcpd : X86Builtin<"int(_Vector<2, double>, _Vector<2, double>)">;
@@ -609,6 +612,9 @@ let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] i
609612
def permti256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Constant int)">;
610613
def permdi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
611614
def extract128i256 : X86Builtin<"_Vector<2, long long int>(_Vector<4, long long int>, _Constant int)">;
615+
}
616+
617+
let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
612618
def insert128i256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>, _Constant int)">;
613619
}
614620

@@ -2945,29 +2951,29 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256
29452951
def extracti32x4_256_mask : X86Builtin<"_Vector<4, int>(_Vector<8, int>, _Constant int, _Vector<4, int>, unsigned char)">;
29462952
}
29472953

2948-
let Features = "avx512dq", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
2954+
let Features = "avx512dq", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
29492955
def insertf32x8 : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Vector<8, float>, _Constant int)">;
29502956
def insertf64x2_512 : X86Builtin<"_Vector<8, double>(_Vector<8, double>, _Vector<2, double>, _Constant int)">;
29512957
def inserti32x8 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<8, int>, _Constant int)">;
29522958
def inserti64x2_512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<2, long long int>, _Constant int)">;
29532959
}
29542960

2955-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
2961+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
29562962
def insertf64x4 : X86Builtin<"_Vector<8, double>(_Vector<8, double>, _Vector<4, double>, _Constant int)">;
29572963
def inserti64x4 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<4, long long int>, _Constant int)">;
29582964
}
29592965

2960-
let Features = "avx512dq,avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2966+
let Features = "avx512dq,avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
29612967
def insertf64x2_256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Vector<2, double>, _Constant int)">;
29622968
def inserti64x2_256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>, _Constant int)">;
29632969
}
29642970

2965-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2971+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
29662972
def insertf32x4_256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<4, float>, _Constant int)">;
29672973
def inserti32x4_256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>, _Constant int)">;
29682974
}
29692975

2970-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
2976+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
29712977
def insertf32x4 : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Vector<4, float>, _Constant int)">;
29722978
def inserti32x4 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<4, int>, _Constant int)">;
29732979
}

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2914,6 +2914,56 @@ static bool interp__builtin_elementwise_triop(
29142914
return true;
29152915
}
29162916

2917+
static bool interp__builtin_x86_insert_subvector(InterpState &S, CodePtr OpPC,
2918+
const CallExpr *Call,
2919+
unsigned ID) {
2920+
assert(Call->getNumArgs() == 3);
2921+
2922+
PrimType ImmPT = *S.getContext().classify(Call->getArg(2));
2923+
APSInt ImmAPS = popToAPSInt(S.Stk, ImmPT);
2924+
uint64_t Index = ImmAPS.getZExtValue();
2925+
2926+
const Pointer &SubVec = S.Stk.pop<Pointer>();
2927+
if (!SubVec.getFieldDesc()->isPrimitiveArray()) {
2928+
return Invalid(S, OpPC);
2929+
}
2930+
2931+
const Pointer &DstVec = S.Stk.pop<Pointer>();
2932+
if (!DstVec.getFieldDesc()->isPrimitiveArray()) {
2933+
return Invalid(S, OpPC);
2934+
}
2935+
2936+
const Pointer &Result = S.Stk.peek<Pointer>();
2937+
2938+
unsigned DstElements = DstVec.getNumElems();
2939+
unsigned SubElements = SubVec.getNumElems();
2940+
2941+
if (SubElements == 0 || DstElements == 0 || (DstElements % SubElements) != 0)
2942+
return Invalid(S, OpPC);
2943+
2944+
unsigned NumLanes = DstElements / SubElements;
2945+
unsigned Lane = static_cast<unsigned>(Index % NumLanes);
2946+
2947+
QualType ElemType = DstVec.getFieldDesc()->getElemQualType();
2948+
PrimType ElemPT = *S.getContext().classify(ElemType);
2949+
2950+
unsigned InsertPos = Lane * SubElements;
2951+
2952+
TYPE_SWITCH(ElemPT, {
2953+
for (unsigned i = 0; i < DstElements; ++i) {
2954+
Result.elem<T>(i) = DstVec.elem<T>(i);
2955+
}
2956+
2957+
for (unsigned i = 0; i < SubElements; ++i) {
2958+
Result.elem<T>(InsertPos + i) = SubVec.elem<T>(i);
2959+
}
2960+
});
2961+
2962+
Result.initializeAllElements();
2963+
2964+
return true;
2965+
}
2966+
29172967
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
29182968
uint32_t BuiltinID) {
29192969
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -3572,6 +3622,24 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
35723622
return interp__builtin_elementwise_triop(S, OpPC, Call,
35733623
llvm::APIntOps::fshr);
35743624

3625+
case X86::BI__builtin_ia32_insertf32x4_256:
3626+
case X86::BI__builtin_ia32_inserti32x4_256:
3627+
case X86::BI__builtin_ia32_insertf64x2_256:
3628+
case X86::BI__builtin_ia32_inserti64x2_256:
3629+
case X86::BI__builtin_ia32_insertf32x4:
3630+
case X86::BI__builtin_ia32_inserti32x4:
3631+
case X86::BI__builtin_ia32_insertf64x2_512:
3632+
case X86::BI__builtin_ia32_inserti64x2_512:
3633+
case X86::BI__builtin_ia32_insertf32x8:
3634+
case X86::BI__builtin_ia32_inserti32x8:
3635+
case X86::BI__builtin_ia32_insertf64x4:
3636+
case X86::BI__builtin_ia32_inserti64x4:
3637+
case X86::BI__builtin_ia32_vinsertf128_ps256:
3638+
case X86::BI__builtin_ia32_vinsertf128_pd256:
3639+
case X86::BI__builtin_ia32_vinsertf128_si256:
3640+
case X86::BI__builtin_ia32_insert128i256:
3641+
return interp__builtin_x86_insert_subvector(S, OpPC, Call, BuiltinID);
3642+
35753643
default:
35763644
S.FFDiag(S.Current->getLocation(OpPC),
35773645
diag::note_invalid_subexpr_in_const_expr)

clang/lib/AST/ExprConstant.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12128,6 +12128,56 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1212812128

1212912129
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1213012130
}
12131+
12132+
case X86::BI__builtin_ia32_insertf32x4_256:
12133+
case X86::BI__builtin_ia32_inserti32x4_256:
12134+
case X86::BI__builtin_ia32_insertf64x2_256:
12135+
case X86::BI__builtin_ia32_inserti64x2_256:
12136+
case X86::BI__builtin_ia32_insertf32x4:
12137+
case X86::BI__builtin_ia32_inserti32x4:
12138+
case X86::BI__builtin_ia32_insertf64x2_512:
12139+
case X86::BI__builtin_ia32_inserti64x2_512:
12140+
case X86::BI__builtin_ia32_insertf32x8:
12141+
case X86::BI__builtin_ia32_inserti32x8:
12142+
case X86::BI__builtin_ia32_insertf64x4:
12143+
case X86::BI__builtin_ia32_inserti64x4:
12144+
case X86::BI__builtin_ia32_vinsertf128_ps256:
12145+
case X86::BI__builtin_ia32_vinsertf128_pd256:
12146+
case X86::BI__builtin_ia32_vinsertf128_si256:
12147+
case X86::BI__builtin_ia32_insert128i256: {
12148+
APValue SourceDst, SourceSub;
12149+
if (!EvaluateAsRValue(Info, E->getArg(0), SourceDst) ||
12150+
!EvaluateAsRValue(Info, E->getArg(1), SourceSub))
12151+
return false;
12152+
12153+
APSInt Imm;
12154+
if (!EvaluateInteger(E->getArg(2), Imm, Info))
12155+
return false;
12156+
12157+
if (!SourceDst.isVector() || !SourceSub.isVector())
12158+
return false;
12159+
12160+
unsigned DstLen = SourceDst.getVectorLength();
12161+
unsigned SubLen = SourceSub.getVectorLength();
12162+
if (SubLen == 0 || DstLen == 0 || (DstLen % SubLen) != 0)
12163+
return false;
12164+
12165+
unsigned NumLanes = DstLen / SubLen;
12166+
unsigned LaneIdx = (Imm.getZExtValue() % NumLanes) * SubLen;
12167+
12168+
SmallVector<APValue, 16> ResultElements;
12169+
ResultElements.reserve(DstLen);
12170+
12171+
for (unsigned EltNum = 0; EltNum < DstLen; ++EltNum) {
12172+
if (EltNum >= LaneIdx && EltNum < LaneIdx + SubLen) {
12173+
ResultElements.push_back(SourceSub.getVectorElt(EltNum - LaneIdx));
12174+
} else {
12175+
ResultElements.push_back(SourceDst.getVectorElt(EltNum));
12176+
}
12177+
}
12178+
12179+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
12180+
}
1213112181
}
1213212182
}
1213312183

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// REQUIRES: x86-registered-target
2+
// RUN: %clang_cc1 -ffreestanding -triple x86_64-unknown-linux-gnu -O0 -target-cpu skylake-avx512 -std=c++17 -emit-llvm -o - %s | FileCheck %s
3+
// RUN: %clang_cc1 -ffreestanding -triple x86_64-unknown-linux-gnu -O0 -target-cpu skylake-avx512 -std=c++17 -fexperimental-new-constant-interpreter -emit-llvm -o - %s | FileCheck %s
4+
5+
#include <immintrin.h>
6+
#include "builtin_test_helpers.h"
7+
8+
//
9+
// AVX256 Insert Tests
10+
//
11+
12+
__m256 test_mm256_insertf32x4(__m256 A, __m128 B) {
13+
// CHECK-LABEL: test_mm256_insertf32x4
14+
return _mm256_insertf32x4(A, B, 1);
15+
}
16+
17+
// Insert 128-bit float vector into upper lane
18+
TEST_CONSTEXPR(match_m256(_mm256_insertf32x4(_mm256_set1_ps(1.0f), _mm_set_ps(40.0f, 30.0f, 20.0f, 10.0f), 1), 1.0f, 1.0f, 1.0f, 1.0f, 10.0f, 20.0f, 30.0f, 40.0f));
19+
20+
__m256i test_mm256_inserti32x4(__m256i A, __m128i B) {
21+
// CHECK-LABEL: test_mm256_inserti32x4
22+
return _mm256_inserti32x4(A, B, 0);
23+
}
24+
25+
// Insert 128-bit integer vector into lower lane
26+
TEST_CONSTEXPR(match_v8si(_mm256_inserti32x4(_mm256_set1_epi32(1), _mm_set_epi32(40, 30, 20, 10), 0), 10, 20, 30, 40, 1, 1, 1, 1));
27+
28+
//
29+
// AVX256 Masked Insert Test
30+
//
31+
32+
__m256 test_mm256_maskz_insertf32x4(__mmask8 U, __m256 A, __m128 B) {
33+
// CHECK-LABEL: test_mm256_maskz_insertf32x4
34+
return _mm256_maskz_insertf32x4(U, A, B, 1);
35+
}
36+
37+
// Test zero mask produces all zeros
38+
TEST_CONSTEXPR(match_m256(
39+
_mm256_maskz_insertf32x4(0x00, _mm256_set1_ps(1.0f),
40+
_mm_set_ps(40.0f, 30.0f, 20.0f, 10.0f), 1),
41+
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f));
42+
43+
//
44+
// AVX Legacy Insert Test
45+
//
46+
47+
__m256 test_mm256_insertf128_ps(__m256 A, __m128 B) {
48+
// CHECK-LABEL: test_mm256_insertf128_ps
49+
return _mm256_insertf128_ps(A, B, 1);
50+
}
51+
52+
// Legacy insertf128 into upper lane
53+
TEST_CONSTEXPR(match_m256(_mm256_insertf128_ps(_mm256_set1_ps(1.0f), _mm_set1_ps(7.0f), 1), 1.0f, 1.0f, 1.0f, 1.0f, 7.0f, 7.0f, 7.0f, 7.0f));
54+
55+
//
56+
//AVX512 Insert Tests
57+
//
58+
59+
__m512 test_mm512_insertf32x4(__m512 A, __m128 B) {
60+
// CHECK-LABEL: test_mm512_insertf32x4
61+
return _mm512_insertf32x4(A, B, 3);
62+
}
63+
64+
// Insert 128-bit into highest lane of 512-bit vector
65+
TEST_CONSTEXPR(match_m512(_mm512_insertf32x4(_mm512_set1_ps(1.0f), _mm_set1_ps(5.0f), 3), 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 5.0f, 5.0f, 5.0f, 5.0f));
66+
67+
__m512 test_mm512_insertf32x8(__m512 A, __m256 B) {
68+
// CHECK-LABEL: test_mm512_insertf32x8
69+
return _mm512_insertf32x8(A, B, 1);
70+
}
71+
72+
// Insert 256-bit into upper half of 512-bit vector
73+
TEST_CONSTEXPR(match_m512(_mm512_insertf32x8(_mm512_set1_ps(1.0f), _mm256_set1_ps(2.0f), 1), 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f));
74+
75+
//
76+
// AVX512 Masked Insert Test
77+
//
78+
79+
__m512 test_mm512_maskz_insertf32x4(__mmask16 U, __m512 A, __m128 B) {
80+
// CHECK-LABEL: test_mm512_maskz_insertf32x4
81+
return _mm512_maskz_insertf32x4(U, A, B, 3);
82+
}
83+
84+
// Test zero mask produces all zeros
85+
TEST_CONSTEXPR(match_m512(
86+
_mm512_maskz_insertf32x4(0x0000, _mm512_set1_ps(1.0f), _mm_set1_ps(5.0f), 3),
87+
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f));

0 commit comments

Comments
 (0)