Skip to content

Commit eb27e83

Browse files
committed
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - Allow GFNI intrinsics to be used in constexpr
Resolves: #169295
1 parent 06c8ee6 commit eb27e83

File tree

6 files changed

+697
-56
lines changed

6 files changed

+697
-56
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -402,39 +402,21 @@ let Features = "avx512f,vaes", Attributes = [NoThrow, Const, RequiredVectorWidth
402402
def aesdeclast512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>)">;
403403
}
404404

405-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
405+
let Features = "gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
406406
def vgf2p8affineinvqb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Constant char)">;
407-
}
408-
409-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
410-
def vgf2p8affineinvqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
411-
}
412-
413-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
414-
def vgf2p8affineinvqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
415-
}
416-
417-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
418407
def vgf2p8affineqb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Constant char)">;
419-
}
420-
421-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
422-
def vgf2p8affineqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
423-
}
424-
425-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
426-
def vgf2p8affineqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
427-
}
428-
429-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
430408
def vgf2p8mulb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>)">;
431409
}
432410

433-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
411+
let Features = "avx,gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
412+
def vgf2p8affineinvqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
413+
def vgf2p8affineqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
434414
def vgf2p8mulb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>)">;
435415
}
436416

437-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
417+
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
418+
def vgf2p8affineinvqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
419+
def vgf2p8affineqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
438420
def vgf2p8mulb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>)">;
439421
}
440422

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3735,6 +3735,100 @@ static bool interp__builtin_ia32_multishiftqb(InterpState &S, CodePtr OpPC,
37353735
return true;
37363736
}
37373737

3738+
/// Constant-evaluates the GFNI affine builtins (vgf2p8affineqb /
/// vgf2p8affineinvqb) for the bytecode interpreter.
///
/// The call takes three arguments: the byte vector X, the byte vector A and an
/// integer immediate. Both vectors are processed in groups of 8 bytes; the 8
/// bytes of A in a group are packed into one 64-bit value, and every byte of X
/// in the same group is transformed with GFNIAffine() using that value and
/// the immediate. Results are written to the vector on top of the stack.
///
/// \param Inverse forwarded to GFNIAffine(); true for the "inv" builtins.
/// \returns false if the argument types are not (vector, vector, integer),
///          true once the destination has been fully initialized.
static bool interp_builtin_ia32_gfni_affine(InterpState &S, CodePtr OpPC,
                                            const CallExpr *Call,
                                            bool Inverse) {
  assert(Call->getNumArgs() == 3);
  QualType XType = Call->getArg(0)->getType();
  QualType AType = Call->getArg(1)->getType();
  QualType ImmType = Call->getArg(2)->getType();
  if (!XType->isVectorType() || !AType->isVectorType() ||
      !ImmType->isIntegerType())
    return false;

  // Arguments come off the stack in reverse order: Imm, then A, then X.
  APSInt Imm = popToAPSInt(S, Call->getArg(2));
  Pointer A = S.Stk.pop<Pointer>();
  Pointer X = S.Stk.pop<Pointer>();

  // The destination vector stays on the stack; it is only peeked.
  const Pointer &Dst = S.Stk.peek<Pointer>();
  const auto *XVecT = XType->castAs<VectorType>();
  const auto *AVecT = AType->castAs<VectorType>();
  assert(XVecT->getNumElements() == AVecT->getNumElements());

  const unsigned NumBitsInByte = 8;
  const unsigned NumBytesInQWord = 8;
  const unsigned NumBitsInQWord = 64;
  const unsigned NumBytes = AVecT->getNumElements();
  const unsigned NumQWords = NumBytes / NumBytesInQWord;
  PrimType AElemT = *S.getContext().classify(AVecT->getElementType());

  // Compute A * X + Imm, one 8-byte group at a time.
  for (unsigned QWordIdx = 0; QWordIdx != NumQWords; ++QWordIdx) {
    const unsigned Base = QWordIdx * NumBytesInQWord;

    // Pack the 8 bytes of A for this group into a single 64-bit value, with
    // byte ByteIdx placed at bit offset ByteIdx * 8.
    APInt AQWord(NumBitsInQWord, 0);
    for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
      uint8_t AByte = 0;
      INT_TYPE_SWITCH(AElemT, {
        AByte = static_cast<uint8_t>(A.elem<T>(Base + ByteIdx));
      });
      AQWord.insertBits(APInt(NumBitsInByte, AByte), ByteIdx * NumBitsInByte);
    }

    // Transform each byte of X in the group. X is read directly instead of
    // being packed into a 64-bit value and unpacked again; the byte handed to
    // GFNIAffine() is the same either way.
    for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
      const unsigned Idx = Base + ByteIdx;
      INT_TYPE_SWITCH(AElemT, {
        const uint8_t XByte = static_cast<uint8_t>(X.elem<T>(Idx));
        Dst.elem<T>(Idx) = T::from(GFNIAffine(XByte, AQWord, Imm, Inverse));
      });
    }
  }
  Dst.initializeAllElements();
  return true;
}
3796+
3797+
/// Constant-evaluates the GFNI byte-multiply builtins (vgf2p8mulb) for the
/// bytecode interpreter.
///
/// Each result byte is GFNIMul() applied to the corresponding bytes of the two
/// vector operands. The result is written to the vector on top of the stack.
///
/// \returns false if either argument is not a vector, true otherwise.
static bool interp__builtin_ia32_gfni_mul(InterpState &S, CodePtr OpPC,
                                          const CallExpr *Call) {
  assert(Call->getNumArgs() == 2);

  const QualType LHSType = Call->getArg(0)->getType();
  const QualType RHSType = Call->getArg(1)->getType();
  if (!(LHSType->isVectorType() && RHSType->isVectorType()))
    return false;

  // Operands are popped in reverse order: second argument first.
  Pointer RHS = S.Stk.pop<Pointer>();
  Pointer LHS = S.Stk.pop<Pointer>();

  // The destination vector is only peeked, so it remains on the stack.
  const Pointer &Result = S.Stk.peek<Pointer>();
  const auto *LHSVecT = LHSType->castAs<VectorType>();
  const auto *RHSVecT = RHSType->castAs<VectorType>();
  assert(LHSVecT->getNumElements() == RHSVecT->getNumElements());

  PrimType ElemT = *S.getContext().classify(LHSVecT->getElementType());

  for (unsigned I = 0, E = LHS.getNumElems(); I != E; ++I) {
    INT_TYPE_SWITCH(ElemT, {
      const uint8_t LByte = static_cast<uint8_t>(LHS.elem<T>(I));
      const uint8_t RByte = static_cast<uint8_t>(RHS.elem<T>(I));
      Result.elem<T>(I) = T::from(GFNIMul(LByte, RByte));
    });
  }

  Result.initializeAllElements();
  return true;
}
3831+
37383832
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
37393833
uint32_t BuiltinID) {
37403834
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -4749,6 +4843,21 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
47494843
return std::pair<unsigned, int>{SrcIdx,
47504844
static_cast<int>(LaneOffset + Index)};
47514845
});
4846+
4847+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
4848+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
4849+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi:
4850+
return interp_builtin_ia32_gfni_affine(S, OpPC, Call, true);
4851+
case X86::BI__builtin_ia32_vgf2p8affineqb_v16qi:
4852+
case X86::BI__builtin_ia32_vgf2p8affineqb_v32qi:
4853+
case X86::BI__builtin_ia32_vgf2p8affineqb_v64qi:
4854+
return interp_builtin_ia32_gfni_affine(S, OpPC, Call, false);
4855+
4856+
case X86::BI__builtin_ia32_vgf2p8mulb_v16qi:
4857+
case X86::BI__builtin_ia32_vgf2p8mulb_v32qi:
4858+
case X86::BI__builtin_ia32_vgf2p8mulb_v64qi:
4859+
return interp__builtin_ia32_gfni_mul(S, OpPC, Call);
4860+
47524861
case X86::BI__builtin_ia32_insertps128:
47534862
return interp__builtin_ia32_shuffle_generic(
47544863
S, OpPC, Call, [](unsigned DstIdx, unsigned Mask) {

clang/lib/AST/ExprConstShared.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
#define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
1616

1717
#include "clang/Basic/TypeTraits.h"
18+
#include <cstdint>
1819

1920
namespace llvm {
2021
class APFloat;
22+
class APInt;
23+
class APSInt;
2124
}
2225
namespace clang {
2326
class QualType;
@@ -74,4 +77,9 @@ void HandleComplexComplexDiv(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
7477
CharUnits GetAlignOfExpr(const ASTContext &Ctx, const Expr *E,
7578
UnaryExprOrTypeTrait ExprKind);
7679

80+
/// Returns the multiplicative inverse of \p Byte in GF(2^8), obtained from a
/// 256-entry lookup table.
uint8_t GFNIMultiplicativeInverse(uint8_t Byte);
81+
/// Multiplies \p AByte and \p BByte as polynomials over GF(2) and reduces the
/// product modulo x^8 + x^4 + x^3 + x + 1 (0x11B), returning the low 8 bits.
uint8_t GFNIMul(uint8_t AByte, uint8_t BByte);
82+
/// Computes one result byte of the GFNI affine transformation. Bit i of the
/// result is the parity of (byte 7 - i of \p AQword AND the input byte), XORed
/// with bit i of \p Imm. When \p Inverse is true the input byte is
/// GFNIMultiplicativeInverse(\p XByte) instead of \p XByte itself.
uint8_t GFNIAffine(uint8_t XByte, llvm::APInt &AQword, llvm::APSInt Imm,
83+
bool Inverse = false);
84+
7785
#endif

clang/lib/AST/ExprConstant.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13712,6 +13712,89 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1371213712
return Success(R, E);
1371313713
}
1371413714

13715+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
13716+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
13717+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi:
13718+
case X86::BI__builtin_ia32_vgf2p8affineqb_v16qi:
13719+
case X86::BI__builtin_ia32_vgf2p8affineqb_v32qi:
13720+
case X86::BI__builtin_ia32_vgf2p8affineqb_v64qi: {
13721+
13722+
APValue X, A;
13723+
APSInt Imm;
13724+
if (!EvaluateAsRValue(Info, E->getArg(0), X) ||
13725+
!EvaluateAsRValue(Info, E->getArg(1), A) ||
13726+
!EvaluateInteger(E->getArg(2), Imm, Info))
13727+
return false;
13728+
13729+
assert(X.isVector() && A.isVector());
13730+
assert(X.getVectorLength() == A.getVectorLength());
13731+
13732+
bool IsInverse = false;
13733+
switch (E->getBuiltinCallee()) {
13734+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
13735+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
13736+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi: {
13737+
IsInverse = true;
13738+
}
13739+
}
13740+
13741+
unsigned NumBitsInByte = 8;
13742+
unsigned NumBytesInQWord = 8;
13743+
unsigned NumBitsInQWord = 64;
13744+
unsigned NumBytes = A.getVectorLength();
13745+
unsigned NumQWords = NumBytes / NumBytesInQWord;
13746+
SmallVector<APValue, 64> Result;
13747+
Result.reserve(NumBytes);
13748+
13749+
// computing A*X + Imm
13750+
for (unsigned QWordIdx = 0; QWordIdx != NumQWords; ++QWordIdx) {
13751+
// Extract the QWords from X, A
13752+
APInt XQWord(NumBitsInQWord, 0);
13753+
APInt AQWord(NumBitsInQWord, 0);
13754+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
13755+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
13756+
APInt XByte = X.getVectorElt(Idx).getInt();
13757+
APInt AByte = A.getVectorElt(Idx).getInt();
13758+
XQWord.insertBits(XByte, ByteIdx * NumBitsInByte);
13759+
AQWord.insertBits(AByte, ByteIdx * NumBitsInByte);
13760+
}
13761+
13762+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
13763+
uint8_t XByte =
13764+
XQWord.lshr(ByteIdx * NumBitsInByte).getLoBits(8).getZExtValue();
13765+
Result.push_back(APValue(APSInt(
13766+
APInt(8, GFNIAffine(XByte, AQWord, Imm, IsInverse)), false)));
13767+
}
13768+
}
13769+
13770+
return Success(APValue(Result.data(), Result.size()), E);
13771+
}
13772+
13773+
case X86::BI__builtin_ia32_vgf2p8mulb_v16qi:
13774+
case X86::BI__builtin_ia32_vgf2p8mulb_v32qi:
13775+
case X86::BI__builtin_ia32_vgf2p8mulb_v64qi: {
13776+
APValue A, B;
13777+
if (!EvaluateAsRValue(Info, E->getArg(0), A) ||
13778+
!EvaluateAsRValue(Info, E->getArg(1), B))
13779+
return false;
13780+
13781+
assert(A.isVector() && B.isVector());
13782+
assert(A.getVectorLength() == B.getVectorLength());
13783+
13784+
unsigned NumBytes = A.getVectorLength();
13785+
SmallVector<APValue, 64> Result;
13786+
Result.reserve(NumBytes);
13787+
13788+
for (unsigned ByteIdx = 0; ByteIdx != NumBytes; ++ByteIdx) {
13789+
uint8_t AByte = A.getVectorElt(ByteIdx).getInt().getZExtValue();
13790+
uint8_t BByte = B.getVectorElt(ByteIdx).getInt().getZExtValue();
13791+
Result.push_back(APValue(
13792+
APSInt(APInt(8, GFNIMul(AByte, BByte)), /*IsUnsigned=*/false)));
13793+
}
13794+
13795+
return Success(APValue(Result.data(), Result.size()), E);
13796+
}
13797+
1371513798
case X86::BI__builtin_ia32_insertf32x4_256:
1371613799
case X86::BI__builtin_ia32_inserti32x4_256:
1371713800
case X86::BI__builtin_ia32_insertf64x2_256:
@@ -19278,6 +19361,87 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1927819361
llvm_unreachable("unknown cast resulting in complex value");
1927919362
}
1928019363

19364+
/// Returns the multiplicative inverse of \p Byte in GF(2^8).
///
/// The value is read from a precomputed 256-entry table indexed by \p Byte.
/// Zero has no inverse; the table maps it to zero.
uint8_t GFNIMultiplicativeInverse(uint8_t Byte) {
  // Lookup table for the multiplicative inverse in GF(2^8). It is declared
  // `static constexpr` so that it lives in read-only storage and is not
  // re-materialised on the stack every time the function is called.
  static constexpr uint8_t GFInv[256] = {
      0x00, 0x01, 0x8d, 0xf6, 0xcb, 0x52, 0x7b, 0xd1, 0xe8, 0x4f, 0x29, 0xc0,
      0xb0, 0xe1, 0xe5, 0xc7, 0x74, 0xb4, 0xaa, 0x4b, 0x99, 0x2b, 0x60, 0x5f,
      0x58, 0x3f, 0xfd, 0xcc, 0xff, 0x40, 0xee, 0xb2, 0x3a, 0x6e, 0x5a, 0xf1,
      0x55, 0x4d, 0xa8, 0xc9, 0xc1, 0x0a, 0x98, 0x15, 0x30, 0x44, 0xa2, 0xc2,
      0x2c, 0x45, 0x92, 0x6c, 0xf3, 0x39, 0x66, 0x42, 0xf2, 0x35, 0x20, 0x6f,
      0x77, 0xbb, 0x59, 0x19, 0x1d, 0xfe, 0x37, 0x67, 0x2d, 0x31, 0xf5, 0x69,
      0xa7, 0x64, 0xab, 0x13, 0x54, 0x25, 0xe9, 0x09, 0xed, 0x5c, 0x05, 0xca,
      0x4c, 0x24, 0x87, 0xbf, 0x18, 0x3e, 0x22, 0xf0, 0x51, 0xec, 0x61, 0x17,
      0x16, 0x5e, 0xaf, 0xd3, 0x49, 0xa6, 0x36, 0x43, 0xf4, 0x47, 0x91, 0xdf,
      0x33, 0x93, 0x21, 0x3b, 0x79, 0xb7, 0x97, 0x85, 0x10, 0xb5, 0xba, 0x3c,
      0xb6, 0x70, 0xd0, 0x06, 0xa1, 0xfa, 0x81, 0x82, 0x83, 0x7e, 0x7f, 0x80,
      0x96, 0x73, 0xbe, 0x56, 0x9b, 0x9e, 0x95, 0xd9, 0xf7, 0x02, 0xb9, 0xa4,
      0xde, 0x6a, 0x32, 0x6d, 0xd8, 0x8a, 0x84, 0x72, 0x2a, 0x14, 0x9f, 0x88,
      0xf9, 0xdc, 0x89, 0x9a, 0xfb, 0x7c, 0x2e, 0xc3, 0x8f, 0xb8, 0x65, 0x48,
      0x26, 0xc8, 0x12, 0x4a, 0xce, 0xe7, 0xd2, 0x62, 0x0c, 0xe0, 0x1f, 0xef,
      0x11, 0x75, 0x78, 0x71, 0xa5, 0x8e, 0x76, 0x3d, 0xbd, 0xbc, 0x86, 0x57,
      0x0b, 0x28, 0x2f, 0xa3, 0xda, 0xd4, 0xe4, 0x0f, 0xa9, 0x27, 0x53, 0x04,
      0x1b, 0xfc, 0xac, 0xe6, 0x7a, 0x07, 0xae, 0x63, 0xc5, 0xdb, 0xe2, 0xea,
      0x94, 0x8b, 0xc4, 0xd5, 0x9d, 0xf8, 0x90, 0x6b, 0xb1, 0x0d, 0xd6, 0xeb,
      0xc6, 0x0e, 0xcf, 0xad, 0x08, 0x4e, 0xd7, 0xe3, 0x5d, 0x50, 0x1e, 0xb3,
      0x5b, 0x23, 0x38, 0x34, 0x68, 0x46, 0x03, 0x8c, 0xdd, 0x9c, 0x7d, 0xa0,
      0xcd, 0x1a, 0x41, 0x1c};

  return GFInv[Byte];
}
19392+
19393+
uint8_t GFNIAffine(uint8_t XByte, APInt &AQword, APSInt Imm, bool Inverse) {
19394+
unsigned NumBitsInByte = 8;
19395+
// Computing the affine transformation
19396+
uint8_t RetByte = 0;
19397+
for (uint32_t BitIdx = 0; BitIdx != NumBitsInByte; ++BitIdx) {
19398+
uint8_t AByte =
19399+
AQword.lshr((7 - static_cast<int32_t>(BitIdx)) * NumBitsInByte)
19400+
.getLoBits(8)
19401+
.getZExtValue();
19402+
uint8_t Product;
19403+
if (Inverse) {
19404+
Product = AByte & GFNIMultiplicativeInverse(XByte);
19405+
} else {
19406+
Product = AByte & XByte;
19407+
}
19408+
uint8_t Parity = 0;
19409+
19410+
// Dot product in GF(2) uses XOR instead of addition
19411+
for (unsigned PBitIdx = 0; PBitIdx != NumBitsInByte; ++PBitIdx) {
19412+
Parity = Parity ^ ((Product >> PBitIdx) & 0x1);
19413+
}
19414+
19415+
uint8_t Temp = Imm[BitIdx] ? 1 : 0;
19416+
RetByte |= (Temp ^ Parity) << BitIdx;
19417+
}
19418+
return RetByte;
19419+
}
19420+
19421+
/// Multiplies two bytes as elements of GF(2^8).
///
/// Both operands are treated as polynomials over GF(2) of degree at most 7.
/// They are multiplied without carries and the product is reduced modulo
/// x^8 + x^4 + x^3 + x + 1 (0x11B).
uint8_t GFNIMul(uint8_t AByte, uint8_t BByte) {
  constexpr unsigned BitsPerByte = 8;
  constexpr unsigned ReductionPoly = 0x11B;

  // Carry-less multiplication: XOR together a shifted copy of AByte for each
  // set bit of BByte.
  uint16_t Acc = 0;
  for (unsigned Shift = 0; Shift < BitsPerByte; ++Shift)
    if (BByte & (1u << Shift))
      Acc ^= static_cast<uint16_t>(AByte << Shift);

  // The raw product can have degree up to 14. Clear every term of degree 8 or
  // higher, highest first, by XORing in the reduction polynomial aligned to
  // that term.
  for (unsigned Degree = 14; Degree >= BitsPerByte; --Degree)
    if (Acc & (1u << Degree))
      Acc ^= static_cast<uint16_t>(ReductionPoly << (Degree - BitsPerByte));

  return static_cast<uint8_t>(Acc);
}
19444+
1928119445
void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
1928219446
APFloat &ResR, APFloat &ResI) {
1928319447
// This is an implementation of complex multiplication according to the

0 commit comments

Comments
 (0)