Skip to content

Commit d7feeda

Browse files
kimsh02RKSimon
andauthored
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - add AVX512 VPTERNLOGD/VPTERNLOGQ intrinsics to be used in constexpr (#158703)
Fix #157698 Add handling for `__builtin_ia32_pternlog[d/q][128/256/512]_mask[z]` intrinsics to `VectorExprEvaluator::VisitCallExpr` and `InterpBuiltin.cpp` with the corresponding test coverage: ``` _mm_mask_ternarylogic_epi32 _mm_maskz_ternarylogic_epi32 _mm_ternarylogic_epi32 _mm256_mask_ternarylogic_epi32 _mm256_maskz_ternarylogic_epi32 _mm256_ternarylogic_epi32 _mm512_mask_ternarylogic_epi32 _mm512_maskz_ternarylogic_epi32 _mm512_ternarylogic_epi32 _mm_mask_ternarylogic_epi64 _mm_maskz_ternarylogic_epi64 _mm_ternarylogic_epi64 _mm256_mask_ternarylogic_epi64 _mm256_maskz_ternarylogic_epi64 _mm256_ternarylogic_epi64 _mm512_mask_ternarylogic_epi64 _mm512_maskz_ternarylogic_epi64 _mm512_ternarylogic_epi64 ``` --------- Co-authored-by: Simon Pilgrim <[email protected]>
1 parent f23c0e6 commit d7feeda

File tree

5 files changed

+800
-4
lines changed

5 files changed

+800
-4
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,28 +2409,36 @@ let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>
24092409
def psraq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<2, long long int>)">;
24102410
def psrld512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<4, int>)">;
24112411
def psrlq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<2, long long int>)">;
2412+
}
2413+
2414+
let Features = "avx512f",
2415+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
24122416
def pternlogd512_mask : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<16, int>, _Vector<16, int>, _Constant int, unsigned short)">;
24132417
def pternlogd512_maskz : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<16, int>, _Vector<16, int>, _Constant int, unsigned short)">;
24142418
def pternlogq512_mask : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>, _Vector<8, long long int>, _Constant int, unsigned char)">;
24152419
def pternlogq512_maskz : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>, _Vector<8, long long int>, _Constant int, unsigned char)">;
24162420
}
24172421

2418-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
2422+
let Features = "avx512vl",
2423+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
24192424
def pternlogd128_mask : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>, _Vector<4, int>, _Constant int, unsigned char)">;
24202425
def pternlogd128_maskz : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>, _Vector<4, int>, _Constant int, unsigned char)">;
24212426
}
24222427

2423-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2428+
let Features = "avx512vl",
2429+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
24242430
def pternlogd256_mask : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>, _Vector<8, int>, _Constant int, unsigned char)">;
24252431
def pternlogd256_maskz : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>, _Vector<8, int>, _Constant int, unsigned char)">;
24262432
}
24272433

2428-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
2434+
let Features = "avx512vl",
2435+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
24292436
def pternlogq128_mask : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>, _Vector<2, long long int>, _Constant int, unsigned char)">;
24302437
def pternlogq128_maskz : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>, _Vector<2, long long int>, _Constant int, unsigned char)">;
24312438
}
24322439

2433-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2440+
let Features = "avx512vl",
2441+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
24342442
def pternlogq256_mask : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Vector<4, long long int>, _Constant int, unsigned char)">;
24352443
def pternlogq256_maskz : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Vector<4, long long int>, _Constant int, unsigned char)">;
24362444
}

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2897,7 +2897,49 @@ static bool interp__builtin_x86_insert_subvector(InterpState &S, CodePtr OpPC,
28972897
});
28982898

28992899
Dst.initializeAllElements();
2900+
return true;
2901+
}
2902+
2903+
static bool interp__builtin_ia32_pternlog(InterpState &S, CodePtr OpPC,
2904+
const CallExpr *Call, bool MaskZ) {
2905+
assert(Call->getNumArgs() == 5);
2906+
2907+
APInt U = popToAPSInt(S, Call->getArg(4)); // Lane mask
2908+
APInt Imm = popToAPSInt(S, Call->getArg(3)); // Ternary truth table
2909+
const Pointer &C = S.Stk.pop<Pointer>();
2910+
const Pointer &B = S.Stk.pop<Pointer>();
2911+
const Pointer &A = S.Stk.pop<Pointer>();
2912+
const Pointer &Dst = S.Stk.peek<Pointer>();
29002913

2914+
unsigned DstLen = A.getNumElems();
2915+
const QualType ElemQT = getElemType(A);
2916+
const OptPrimType ElemPT = S.getContext().classify(ElemQT);
2917+
unsigned LaneWidth = S.getASTContext().getTypeSize(ElemQT);
2918+
bool DstUnsigned = ElemQT->isUnsignedIntegerOrEnumerationType();
2919+
2920+
INT_TYPE_SWITCH_NO_BOOL(*ElemPT, {
2921+
for (unsigned I = 0; I != DstLen; ++I) {
2922+
APInt ALane = A.elem<T>(I).toAPSInt();
2923+
APInt BLane = B.elem<T>(I).toAPSInt();
2924+
APInt CLane = C.elem<T>(I).toAPSInt();
2925+
APInt RLane(LaneWidth, 0);
2926+
if (U[I]) { // If lane not masked, compute ternary logic.
2927+
for (unsigned Bit = 0; Bit != LaneWidth; ++Bit) {
2928+
unsigned ABit = ALane[Bit];
2929+
unsigned BBit = BLane[Bit];
2930+
unsigned CBit = CLane[Bit];
2931+
unsigned Idx = (ABit << 2) | (BBit << 1) | (CBit);
2932+
RLane.setBitVal(Bit, Imm[Idx]);
2933+
}
2934+
Dst.elem<T>(I) = static_cast<T>(APSInt(RLane, DstUnsigned));
2935+
} else if (MaskZ) { // If zero masked, zero the lane.
2936+
Dst.elem<T>(I) = static_cast<T>(APSInt(RLane, DstUnsigned));
2937+
} else { // Just masked, put in A lane.
2938+
Dst.elem<T>(I) = static_cast<T>(APSInt(ALane, DstUnsigned));
2939+
}
2940+
}
2941+
});
2942+
Dst.initializeAllElements();
29012943
return true;
29022944
}
29032945

@@ -3760,6 +3802,20 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
37603802
S, OpPC, Call,
37613803
[](const APSInt &LHS, const APSInt &RHS) { return LHS + RHS; });
37623804

3805+
case X86::BI__builtin_ia32_pternlogd128_mask:
3806+
case X86::BI__builtin_ia32_pternlogd256_mask:
3807+
case X86::BI__builtin_ia32_pternlogd512_mask:
3808+
case X86::BI__builtin_ia32_pternlogq128_mask:
3809+
case X86::BI__builtin_ia32_pternlogq256_mask:
3810+
case X86::BI__builtin_ia32_pternlogq512_mask:
3811+
return interp__builtin_ia32_pternlog(S, OpPC, Call, /*MaskZ=*/false);
3812+
case X86::BI__builtin_ia32_pternlogd128_maskz:
3813+
case X86::BI__builtin_ia32_pternlogd256_maskz:
3814+
case X86::BI__builtin_ia32_pternlogd512_maskz:
3815+
case X86::BI__builtin_ia32_pternlogq128_maskz:
3816+
case X86::BI__builtin_ia32_pternlogq256_maskz:
3817+
case X86::BI__builtin_ia32_pternlogq512_maskz:
3818+
return interp__builtin_ia32_pternlog(S, OpPC, Call, /*MaskZ=*/true);
37633819
case Builtin::BI__builtin_elementwise_fshl:
37643820
return interp__builtin_elementwise_triop(S, OpPC, Call,
37653821
llvm::APIntOps::fshl);

clang/lib/AST/ExprConstant.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12168,6 +12168,97 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1216812168
return Success(R, E);
1216912169
}
1217012170

12171+
case X86::BI__builtin_ia32_pternlogd128_mask:
12172+
case X86::BI__builtin_ia32_pternlogd256_mask:
12173+
case X86::BI__builtin_ia32_pternlogd512_mask:
12174+
case X86::BI__builtin_ia32_pternlogq128_mask:
12175+
case X86::BI__builtin_ia32_pternlogq256_mask:
12176+
case X86::BI__builtin_ia32_pternlogq512_mask: {
12177+
APValue AValue, BValue, CValue, ImmValue, UValue;
12178+
if (!EvaluateAsRValue(Info, E->getArg(0), AValue) ||
12179+
!EvaluateAsRValue(Info, E->getArg(1), BValue) ||
12180+
!EvaluateAsRValue(Info, E->getArg(2), CValue) ||
12181+
!EvaluateAsRValue(Info, E->getArg(3), ImmValue) ||
12182+
!EvaluateAsRValue(Info, E->getArg(4), UValue))
12183+
return false;
12184+
12185+
QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
12186+
bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
12187+
APInt Imm = ImmValue.getInt();
12188+
APInt U = UValue.getInt();
12189+
unsigned ResultLen = AValue.getVectorLength();
12190+
SmallVector<APValue, 16> ResultElements;
12191+
ResultElements.reserve(ResultLen);
12192+
12193+
for (unsigned EltNum = 0; EltNum < ResultLen; ++EltNum) {
12194+
APInt ALane = AValue.getVectorElt(EltNum).getInt();
12195+
APInt BLane = BValue.getVectorElt(EltNum).getInt();
12196+
APInt CLane = CValue.getVectorElt(EltNum).getInt();
12197+
12198+
if (U[EltNum]) {
12199+
unsigned BitWidth = ALane.getBitWidth();
12200+
APInt ResLane(BitWidth, 0);
12201+
12202+
for (unsigned Bit = 0; Bit < BitWidth; ++Bit) {
12203+
unsigned ABit = ALane[Bit];
12204+
unsigned BBit = BLane[Bit];
12205+
unsigned CBit = CLane[Bit];
12206+
12207+
unsigned Idx = (ABit << 2) | (BBit << 1) | CBit;
12208+
ResLane.setBitVal(Bit, Imm[Idx]);
12209+
}
12210+
ResultElements.push_back(APValue(APSInt(ResLane, DestUnsigned)));
12211+
} else {
12212+
ResultElements.push_back(APValue(APSInt(ALane, DestUnsigned)));
12213+
}
12214+
}
12215+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
12216+
}
12217+
case X86::BI__builtin_ia32_pternlogd128_maskz:
12218+
case X86::BI__builtin_ia32_pternlogd256_maskz:
12219+
case X86::BI__builtin_ia32_pternlogd512_maskz:
12220+
case X86::BI__builtin_ia32_pternlogq128_maskz:
12221+
case X86::BI__builtin_ia32_pternlogq256_maskz:
12222+
case X86::BI__builtin_ia32_pternlogq512_maskz: {
12223+
APValue AValue, BValue, CValue, ImmValue, UValue;
12224+
if (!EvaluateAsRValue(Info, E->getArg(0), AValue) ||
12225+
!EvaluateAsRValue(Info, E->getArg(1), BValue) ||
12226+
!EvaluateAsRValue(Info, E->getArg(2), CValue) ||
12227+
!EvaluateAsRValue(Info, E->getArg(3), ImmValue) ||
12228+
!EvaluateAsRValue(Info, E->getArg(4), UValue))
12229+
return false;
12230+
12231+
QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
12232+
bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
12233+
APInt Imm = ImmValue.getInt();
12234+
APInt U = UValue.getInt();
12235+
unsigned ResultLen = AValue.getVectorLength();
12236+
SmallVector<APValue, 16> ResultElements;
12237+
ResultElements.reserve(ResultLen);
12238+
12239+
for (unsigned EltNum = 0; EltNum < ResultLen; ++EltNum) {
12240+
APInt ALane = AValue.getVectorElt(EltNum).getInt();
12241+
APInt BLane = BValue.getVectorElt(EltNum).getInt();
12242+
APInt CLane = CValue.getVectorElt(EltNum).getInt();
12243+
12244+
unsigned BitWidth = ALane.getBitWidth();
12245+
APInt ResLane(BitWidth, 0);
12246+
12247+
if (U[EltNum]) {
12248+
for (unsigned Bit = 0; Bit < BitWidth; ++Bit) {
12249+
unsigned ABit = ALane[Bit];
12250+
unsigned BBit = BLane[Bit];
12251+
unsigned CBit = CLane[Bit];
12252+
12253+
unsigned Idx = (ABit << 2) | (BBit << 1) | CBit;
12254+
ResLane.setBitVal(Bit, Imm[Idx]);
12255+
}
12256+
}
12257+
ResultElements.push_back(APValue(APSInt(ResLane, DestUnsigned)));
12258+
}
12259+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
12260+
}
12261+
1217112262
case Builtin::BI__builtin_elementwise_clzg:
1217212263
case Builtin::BI__builtin_elementwise_ctzg: {
1217312264
APValue SourceLHS;

0 commit comments

Comments
 (0)