Skip to content

Commit 4ce7ac1

Browse files
committed
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - Allow GFNI intrinsics to be used in constexpr
Resolves:#169295
1 parent 6459f39 commit 4ce7ac1

File tree

6 files changed

+697
-56
lines changed

6 files changed

+697
-56
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -408,39 +408,21 @@ let Features = "avx512f,vaes", Attributes = [NoThrow, Const, RequiredVectorWidth
408408
def aesdeclast512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>)">;
409409
}
410410

411-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
411+
let Features = "gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
412412
def vgf2p8affineinvqb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Constant char)">;
413-
}
414-
415-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
416-
def vgf2p8affineinvqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
417-
}
418-
419-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
420-
def vgf2p8affineinvqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
421-
}
422-
423-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
424413
def vgf2p8affineqb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Constant char)">;
425-
}
426-
427-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
428-
def vgf2p8affineqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
429-
}
430-
431-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
432-
def vgf2p8affineqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
433-
}
434-
435-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
436414
def vgf2p8mulb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>)">;
437415
}
438416

439-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
417+
let Features = "avx,gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
418+
def vgf2p8affineinvqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
419+
def vgf2p8affineqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
440420
def vgf2p8mulb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>)">;
441421
}
442422

443-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
423+
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
424+
def vgf2p8affineinvqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
425+
def vgf2p8affineqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
444426
def vgf2p8mulb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>)">;
445427
}
446428

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3531,6 +3531,100 @@ static bool interp__builtin_ia32_shufbitqmb_mask(InterpState &S, CodePtr OpPC,
35313531
return true;
35323532
}
35333533

3534+
static bool interp_builtin_ia32_gfni_affine(InterpState &S, CodePtr OpPC,
3535+
const CallExpr *Call,
3536+
bool Inverse) {
3537+
assert(Call->getNumArgs() == 3);
3538+
QualType XType = Call->getArg(0)->getType();
3539+
QualType AType = Call->getArg(1)->getType();
3540+
QualType ImmType = Call->getArg(2)->getType();
3541+
if (!XType->isVectorType() || !AType->isVectorType() ||
3542+
!ImmType->isIntegerType()) {
3543+
return false;
3544+
}
3545+
3546+
Pointer X, A;
3547+
APSInt Imm = popToAPSInt(S, Call->getArg(2));
3548+
A = S.Stk.pop<Pointer>();
3549+
X = S.Stk.pop<Pointer>();
3550+
3551+
const Pointer &Dst = S.Stk.peek<Pointer>();
3552+
const auto *XVecT = XType->castAs<VectorType>();
3553+
const auto *AVecT = AType->castAs<VectorType>();
3554+
assert(XVecT->getNumElements() == AVecT->getNumElements());
3555+
unsigned NumBytesInQWord = 8;
3556+
unsigned NumBytes = AVecT->getNumElements();
3557+
unsigned NumBitsInQWord = 64;
3558+
unsigned NumQWords = NumBytes / NumBytesInQWord;
3559+
unsigned NumBitsInByte = 8;
3560+
PrimType AElemT = *S.getContext().classify(AVecT->getElementType());
3561+
3562+
// computing A*X + Imm
3563+
for (unsigned QWordIdx = 0; QWordIdx != NumQWords; ++QWordIdx) {
3564+
// Extract the QWords from X, A
3565+
APInt XQWord(NumBitsInQWord, 0);
3566+
APInt AQWord(NumBitsInQWord, 0);
3567+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
3568+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
3569+
uint8_t XByte;
3570+
uint8_t AByte;
3571+
INT_TYPE_SWITCH(AElemT, {
3572+
XByte = static_cast<uint8_t>(X.elem<T>(Idx));
3573+
AByte = static_cast<uint8_t>(A.elem<T>(Idx));
3574+
});
3575+
3576+
XQWord.insertBits(APInt(NumBitsInByte, XByte), ByteIdx * NumBitsInByte);
3577+
AQWord.insertBits(APInt(NumBitsInByte, AByte), ByteIdx * NumBitsInByte);
3578+
}
3579+
3580+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
3581+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
3582+
uint8_t XByte =
3583+
XQWord.lshr(ByteIdx * NumBitsInByte).getLoBits(8).getZExtValue();
3584+
INT_TYPE_SWITCH(AElemT, {
3585+
Dst.elem<T>(Idx) = T::from(GFNIAffine(XByte, AQWord, Imm, Inverse));
3586+
});
3587+
}
3588+
}
3589+
Dst.initializeAllElements();
3590+
return true;
3591+
}
3592+
3593+
static bool interp__builtin_ia32_gfni_mul(InterpState &S, CodePtr OpPC,
3594+
const CallExpr *Call) {
3595+
assert(Call->getNumArgs() == 2);
3596+
3597+
QualType AType = Call->getArg(0)->getType();
3598+
QualType BType = Call->getArg(1)->getType();
3599+
if (!AType->isVectorType() || !BType->isVectorType()) {
3600+
return false;
3601+
}
3602+
3603+
Pointer A, B;
3604+
B = S.Stk.pop<Pointer>();
3605+
A = S.Stk.pop<Pointer>();
3606+
3607+
const Pointer &Dst = S.Stk.peek<Pointer>();
3608+
const auto *AVecT = AType->castAs<VectorType>();
3609+
const auto *BVecT = BType->castAs<VectorType>();
3610+
assert(AVecT->getNumElements() == BVecT->getNumElements());
3611+
3612+
PrimType AElemT = *S.getContext().classify(AVecT->getElementType());
3613+
unsigned NumBytes = A.getNumElems();
3614+
3615+
for (unsigned ByteIdx = 0; ByteIdx != NumBytes; ++ByteIdx) {
3616+
uint8_t AByte, BByte;
3617+
INT_TYPE_SWITCH(AElemT, {
3618+
AByte = static_cast<uint8_t>(A.elem<T>(ByteIdx));
3619+
BByte = static_cast<uint8_t>(B.elem<T>(ByteIdx));
3620+
Dst.elem<T>(ByteIdx) = T::from(GFNIMul(AByte, BByte));
3621+
});
3622+
}
3623+
3624+
Dst.initializeAllElements();
3625+
return true;
3626+
}
3627+
35343628
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
35353629
uint32_t BuiltinID) {
35363630
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -4545,6 +4639,21 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
45454639
return std::pair<unsigned, int>{SrcIdx,
45464640
static_cast<int>(LaneOffset + Index)};
45474641
});
4642+
4643+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
4644+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
4645+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi:
4646+
return interp_builtin_ia32_gfni_affine(S, OpPC, Call, true);
4647+
case X86::BI__builtin_ia32_vgf2p8affineqb_v16qi:
4648+
case X86::BI__builtin_ia32_vgf2p8affineqb_v32qi:
4649+
case X86::BI__builtin_ia32_vgf2p8affineqb_v64qi:
4650+
return interp_builtin_ia32_gfni_affine(S, OpPC, Call, false);
4651+
4652+
case X86::BI__builtin_ia32_vgf2p8mulb_v16qi:
4653+
case X86::BI__builtin_ia32_vgf2p8mulb_v32qi:
4654+
case X86::BI__builtin_ia32_vgf2p8mulb_v64qi:
4655+
return interp__builtin_ia32_gfni_mul(S, OpPC, Call);
4656+
45484657
case X86::BI__builtin_ia32_insertps128:
45494658
return interp__builtin_ia32_shuffle_generic(
45504659
S, OpPC, Call, [](unsigned DstIdx, unsigned Mask) {

clang/lib/AST/ExprConstShared.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
#define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
1616

1717
#include "clang/Basic/TypeTraits.h"
18+
#include <cstdint>
1819

1920
namespace llvm {
2021
class APFloat;
22+
class APInt;
23+
class APSInt;
2124
}
2225
namespace clang {
2326
class QualType;
@@ -74,4 +77,9 @@ void HandleComplexComplexDiv(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
7477
CharUnits GetAlignOfExpr(const ASTContext &Ctx, const Expr *E,
7578
UnaryExprOrTypeTrait ExprKind);
7679

80+
uint8_t GFNIMultiplicativeInverse(uint8_t Byte);
81+
uint8_t GFNIMul(uint8_t AByte, uint8_t BByte);
82+
uint8_t GFNIAffine(uint8_t XByte, llvm::APInt &AQword, llvm::APSInt Imm,
83+
bool Inverse = false);
84+
7785
#endif

clang/lib/AST/ExprConstant.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13517,6 +13517,89 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1351713517
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1351813518
}
1351913519

13520+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
13521+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
13522+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi:
13523+
case X86::BI__builtin_ia32_vgf2p8affineqb_v16qi:
13524+
case X86::BI__builtin_ia32_vgf2p8affineqb_v32qi:
13525+
case X86::BI__builtin_ia32_vgf2p8affineqb_v64qi: {
13526+
13527+
APValue X, A;
13528+
APSInt Imm;
13529+
if (!EvaluateAsRValue(Info, E->getArg(0), X) ||
13530+
!EvaluateAsRValue(Info, E->getArg(1), A) ||
13531+
!EvaluateInteger(E->getArg(2), Imm, Info))
13532+
return false;
13533+
13534+
assert(X.isVector() && A.isVector());
13535+
assert(X.getVectorLength() == A.getVectorLength());
13536+
13537+
bool IsInverse = false;
13538+
switch (E->getBuiltinCallee()) {
13539+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
13540+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
13541+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi: {
13542+
IsInverse = true;
13543+
}
13544+
}
13545+
13546+
unsigned NumBitsInByte = 8;
13547+
unsigned NumBytesInQWord = 8;
13548+
unsigned NumBitsInQWord = 64;
13549+
unsigned NumBytes = A.getVectorLength();
13550+
unsigned NumQWords = NumBytes / NumBytesInQWord;
13551+
SmallVector<APValue, 64> Result;
13552+
Result.reserve(NumBytes);
13553+
13554+
// computing A*X + Imm
13555+
for (unsigned QWordIdx = 0; QWordIdx != NumQWords; ++QWordIdx) {
13556+
// Extract the QWords from X, A
13557+
APInt XQWord(NumBitsInQWord, 0);
13558+
APInt AQWord(NumBitsInQWord, 0);
13559+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
13560+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
13561+
APInt XByte = X.getVectorElt(Idx).getInt();
13562+
APInt AByte = A.getVectorElt(Idx).getInt();
13563+
XQWord.insertBits(XByte, ByteIdx * NumBitsInByte);
13564+
AQWord.insertBits(AByte, ByteIdx * NumBitsInByte);
13565+
}
13566+
13567+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
13568+
uint8_t XByte =
13569+
XQWord.lshr(ByteIdx * NumBitsInByte).getLoBits(8).getZExtValue();
13570+
Result.push_back(APValue(APSInt(
13571+
APInt(8, GFNIAffine(XByte, AQWord, Imm, IsInverse)), false)));
13572+
}
13573+
}
13574+
13575+
return Success(APValue(Result.data(), Result.size()), E);
13576+
}
13577+
13578+
case X86::BI__builtin_ia32_vgf2p8mulb_v16qi:
13579+
case X86::BI__builtin_ia32_vgf2p8mulb_v32qi:
13580+
case X86::BI__builtin_ia32_vgf2p8mulb_v64qi: {
13581+
APValue A, B;
13582+
if (!EvaluateAsRValue(Info, E->getArg(0), A) ||
13583+
!EvaluateAsRValue(Info, E->getArg(1), B))
13584+
return false;
13585+
13586+
assert(A.isVector() && B.isVector());
13587+
assert(A.getVectorLength() == B.getVectorLength());
13588+
13589+
unsigned NumBytes = A.getVectorLength();
13590+
SmallVector<APValue, 64> Result;
13591+
Result.reserve(NumBytes);
13592+
13593+
for (unsigned ByteIdx = 0; ByteIdx != NumBytes; ++ByteIdx) {
13594+
uint8_t AByte = A.getVectorElt(ByteIdx).getInt().getZExtValue();
13595+
uint8_t BByte = B.getVectorElt(ByteIdx).getInt().getZExtValue();
13596+
Result.push_back(APValue(
13597+
APSInt(APInt(8, GFNIMul(AByte, BByte)), /*IsUnsigned=*/false)));
13598+
}
13599+
13600+
return Success(APValue(Result.data(), Result.size()), E);
13601+
}
13602+
1352013603
case X86::BI__builtin_ia32_insertf32x4_256:
1352113604
case X86::BI__builtin_ia32_inserti32x4_256:
1352213605
case X86::BI__builtin_ia32_insertf64x2_256:
@@ -19008,6 +19091,87 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1900819091
llvm_unreachable("unknown cast resulting in complex value");
1900919092
}
1901019093

19094+
uint8_t GFNIMultiplicativeInverse(uint8_t Byte) {
19095+
// Lookup Table for Multiplicative Inverse in GF(2^8)
19096+
const uint8_t GFInv[256] = {
19097+
0x00, 0x01, 0x8d, 0xf6, 0xcb, 0x52, 0x7b, 0xd1, 0xe8, 0x4f, 0x29, 0xc0,
19098+
0xb0, 0xe1, 0xe5, 0xc7, 0x74, 0xb4, 0xaa, 0x4b, 0x99, 0x2b, 0x60, 0x5f,
19099+
0x58, 0x3f, 0xfd, 0xcc, 0xff, 0x40, 0xee, 0xb2, 0x3a, 0x6e, 0x5a, 0xf1,
19100+
0x55, 0x4d, 0xa8, 0xc9, 0xc1, 0x0a, 0x98, 0x15, 0x30, 0x44, 0xa2, 0xc2,
19101+
0x2c, 0x45, 0x92, 0x6c, 0xf3, 0x39, 0x66, 0x42, 0xf2, 0x35, 0x20, 0x6f,
19102+
0x77, 0xbb, 0x59, 0x19, 0x1d, 0xfe, 0x37, 0x67, 0x2d, 0x31, 0xf5, 0x69,
19103+
0xa7, 0x64, 0xab, 0x13, 0x54, 0x25, 0xe9, 0x09, 0xed, 0x5c, 0x05, 0xca,
19104+
0x4c, 0x24, 0x87, 0xbf, 0x18, 0x3e, 0x22, 0xf0, 0x51, 0xec, 0x61, 0x17,
19105+
0x16, 0x5e, 0xaf, 0xd3, 0x49, 0xa6, 0x36, 0x43, 0xf4, 0x47, 0x91, 0xdf,
19106+
0x33, 0x93, 0x21, 0x3b, 0x79, 0xb7, 0x97, 0x85, 0x10, 0xb5, 0xba, 0x3c,
19107+
0xb6, 0x70, 0xd0, 0x06, 0xa1, 0xfa, 0x81, 0x82, 0x83, 0x7e, 0x7f, 0x80,
19108+
0x96, 0x73, 0xbe, 0x56, 0x9b, 0x9e, 0x95, 0xd9, 0xf7, 0x02, 0xb9, 0xa4,
19109+
0xde, 0x6a, 0x32, 0x6d, 0xd8, 0x8a, 0x84, 0x72, 0x2a, 0x14, 0x9f, 0x88,
19110+
0xf9, 0xdc, 0x89, 0x9a, 0xfb, 0x7c, 0x2e, 0xc3, 0x8f, 0xb8, 0x65, 0x48,
19111+
0x26, 0xc8, 0x12, 0x4a, 0xce, 0xe7, 0xd2, 0x62, 0x0c, 0xe0, 0x1f, 0xef,
19112+
0x11, 0x75, 0x78, 0x71, 0xa5, 0x8e, 0x76, 0x3d, 0xbd, 0xbc, 0x86, 0x57,
19113+
0x0b, 0x28, 0x2f, 0xa3, 0xda, 0xd4, 0xe4, 0x0f, 0xa9, 0x27, 0x53, 0x04,
19114+
0x1b, 0xfc, 0xac, 0xe6, 0x7a, 0x07, 0xae, 0x63, 0xc5, 0xdb, 0xe2, 0xea,
19115+
0x94, 0x8b, 0xc4, 0xd5, 0x9d, 0xf8, 0x90, 0x6b, 0xb1, 0x0d, 0xd6, 0xeb,
19116+
0xc6, 0x0e, 0xcf, 0xad, 0x08, 0x4e, 0xd7, 0xe3, 0x5d, 0x50, 0x1e, 0xb3,
19117+
0x5b, 0x23, 0x38, 0x34, 0x68, 0x46, 0x03, 0x8c, 0xdd, 0x9c, 0x7d, 0xa0,
19118+
0xcd, 0x1a, 0x41, 0x1c};
19119+
19120+
return GFInv[Byte];
19121+
}
19122+
19123+
uint8_t GFNIAffine(uint8_t XByte, APInt &AQword, APSInt Imm, bool Inverse) {
19124+
unsigned NumBitsInByte = 8;
19125+
// Computing the affine transformation
19126+
uint8_t RetByte = 0;
19127+
for (uint32_t BitIdx = 0; BitIdx != NumBitsInByte; ++BitIdx) {
19128+
uint8_t AByte =
19129+
AQword.lshr((7 - static_cast<int32_t>(BitIdx)) * NumBitsInByte)
19130+
.getLoBits(8)
19131+
.getZExtValue();
19132+
uint8_t Product;
19133+
if (Inverse) {
19134+
Product = AByte & GFNIMultiplicativeInverse(XByte);
19135+
} else {
19136+
Product = AByte & XByte;
19137+
}
19138+
uint8_t Parity = 0;
19139+
19140+
// Dot product in GF(2) uses XOR instead of addition
19141+
for (unsigned PBitIdx = 0; PBitIdx != NumBitsInByte; ++PBitIdx) {
19142+
Parity = Parity ^ ((Product >> PBitIdx) & 0x1);
19143+
}
19144+
19145+
uint8_t Temp = Imm[BitIdx] ? 1 : 0;
19146+
RetByte |= (Temp ^ Parity) << BitIdx;
19147+
}
19148+
return RetByte;
19149+
}
19150+
19151+
uint8_t GFNIMul(uint8_t AByte, uint8_t BByte) {
19152+
// Multiplying two polynomials of degree 7
19153+
// Polynomial of degree 7
19154+
// x^7 + x^6 + x^5 + x^4 + x^3 + x^2 + x + 1
19155+
uint16_t TWord = 0;
19156+
unsigned NumBitsInByte = 8;
19157+
for (unsigned BitIdx = 0; BitIdx != NumBitsInByte; ++BitIdx) {
19158+
if ((BByte >> BitIdx) & 0x1) {
19159+
TWord = TWord ^ (AByte << BitIdx);
19160+
}
19161+
}
19162+
19163+
// When multiplying two polynomials of degree 7
19164+
// results in a polynomial of degree 14
19165+
// so the result has to be reduced to 7
19166+
// Reduction polynomial is x^8 + x^4 + x^3 + x + 1 i.e. 0x11B
19167+
for (int32_t BitIdx = 14; BitIdx > 7; --BitIdx) {
19168+
if ((TWord >> BitIdx) & 0x1) {
19169+
TWord = TWord ^ (0x11B << (BitIdx - 8));
19170+
}
19171+
}
19172+
return (TWord & 0xFF);
19173+
}
19174+
1901119175
void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
1901219176
APFloat &ResR, APFloat &ResI) {
1901319177
// This is an implementation of complex multiplication according to the

0 commit comments

Comments
 (0)