Skip to content

Commit 4547c37

Browse files
chaitanyavHoney Goyal
authored andcommitted
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - Allow GFNI intrinsics to be used in constexpr (llvm#169619)
Resolves llvm#169295
1 parent 8b888c5 commit 4547c37

File tree

6 files changed

+682
-56
lines changed

6 files changed

+682
-56
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -408,39 +408,21 @@ let Features = "avx512f,vaes", Attributes = [NoThrow, Const, RequiredVectorWidth
408408
def aesdeclast512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>)">;
409409
}
410410

411-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
411+
let Features = "gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
412412
def vgf2p8affineinvqb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Constant char)">;
413-
}
414-
415-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
416-
def vgf2p8affineinvqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
417-
}
418-
419-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
420-
def vgf2p8affineinvqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
421-
}
422-
423-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
424413
def vgf2p8affineqb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Constant char)">;
425-
}
426-
427-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
428-
def vgf2p8affineqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
429-
}
430-
431-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
432-
def vgf2p8affineqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
433-
}
434-
435-
let Features = "gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
436414
def vgf2p8mulb_v16qi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>)">;
437415
}
438416

439-
let Features = "avx,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
417+
let Features = "avx,gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
418+
def vgf2p8affineinvqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
419+
def vgf2p8affineqb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Constant char)">;
440420
def vgf2p8mulb_v32qi : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>)">;
441421
}
442422

443-
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
423+
let Features = "avx512f,gfni", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
424+
def vgf2p8affineinvqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
425+
def vgf2p8affineqb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Constant char)">;
444426
def vgf2p8mulb_v64qi : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>)">;
445427
}
446428

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3830,6 +3830,100 @@ static bool interp__builtin_ia32_multishiftqb(InterpState &S, CodePtr OpPC,
38303830
return true;
38313831
}
38323832

3833+
static bool interp_builtin_ia32_gfni_affine(InterpState &S, CodePtr OpPC,
3834+
const CallExpr *Call,
3835+
bool Inverse) {
3836+
assert(Call->getNumArgs() == 3);
3837+
QualType XType = Call->getArg(0)->getType();
3838+
QualType AType = Call->getArg(1)->getType();
3839+
QualType ImmType = Call->getArg(2)->getType();
3840+
if (!XType->isVectorType() || !AType->isVectorType() ||
3841+
!ImmType->isIntegerType()) {
3842+
return false;
3843+
}
3844+
3845+
Pointer X, A;
3846+
APSInt Imm = popToAPSInt(S, Call->getArg(2));
3847+
A = S.Stk.pop<Pointer>();
3848+
X = S.Stk.pop<Pointer>();
3849+
3850+
const Pointer &Dst = S.Stk.peek<Pointer>();
3851+
const auto *XVecT = XType->castAs<VectorType>();
3852+
const auto *AVecT = AType->castAs<VectorType>();
3853+
assert(XVecT->getNumElements() == AVecT->getNumElements());
3854+
unsigned NumBytesInQWord = 8;
3855+
unsigned NumBytes = AVecT->getNumElements();
3856+
unsigned NumBitsInQWord = 64;
3857+
unsigned NumQWords = NumBytes / NumBytesInQWord;
3858+
unsigned NumBitsInByte = 8;
3859+
PrimType AElemT = *S.getContext().classify(AVecT->getElementType());
3860+
3861+
// computing A*X + Imm
3862+
for (unsigned QWordIdx = 0; QWordIdx != NumQWords; ++QWordIdx) {
3863+
// Extract the QWords from X, A
3864+
APInt XQWord(NumBitsInQWord, 0);
3865+
APInt AQWord(NumBitsInQWord, 0);
3866+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
3867+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
3868+
uint8_t XByte;
3869+
uint8_t AByte;
3870+
INT_TYPE_SWITCH(AElemT, {
3871+
XByte = static_cast<uint8_t>(X.elem<T>(Idx));
3872+
AByte = static_cast<uint8_t>(A.elem<T>(Idx));
3873+
});
3874+
3875+
XQWord.insertBits(APInt(NumBitsInByte, XByte), ByteIdx * NumBitsInByte);
3876+
AQWord.insertBits(APInt(NumBitsInByte, AByte), ByteIdx * NumBitsInByte);
3877+
}
3878+
3879+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
3880+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
3881+
uint8_t XByte =
3882+
XQWord.lshr(ByteIdx * NumBitsInByte).getLoBits(8).getZExtValue();
3883+
INT_TYPE_SWITCH(AElemT, {
3884+
Dst.elem<T>(Idx) = T::from(GFNIAffine(XByte, AQWord, Imm, Inverse));
3885+
});
3886+
}
3887+
}
3888+
Dst.initializeAllElements();
3889+
return true;
3890+
}
3891+
3892+
static bool interp__builtin_ia32_gfni_mul(InterpState &S, CodePtr OpPC,
3893+
const CallExpr *Call) {
3894+
assert(Call->getNumArgs() == 2);
3895+
3896+
QualType AType = Call->getArg(0)->getType();
3897+
QualType BType = Call->getArg(1)->getType();
3898+
if (!AType->isVectorType() || !BType->isVectorType()) {
3899+
return false;
3900+
}
3901+
3902+
Pointer A, B;
3903+
B = S.Stk.pop<Pointer>();
3904+
A = S.Stk.pop<Pointer>();
3905+
3906+
const Pointer &Dst = S.Stk.peek<Pointer>();
3907+
const auto *AVecT = AType->castAs<VectorType>();
3908+
const auto *BVecT = BType->castAs<VectorType>();
3909+
assert(AVecT->getNumElements() == BVecT->getNumElements());
3910+
3911+
PrimType AElemT = *S.getContext().classify(AVecT->getElementType());
3912+
unsigned NumBytes = A.getNumElems();
3913+
3914+
for (unsigned ByteIdx = 0; ByteIdx != NumBytes; ++ByteIdx) {
3915+
uint8_t AByte, BByte;
3916+
INT_TYPE_SWITCH(AElemT, {
3917+
AByte = static_cast<uint8_t>(A.elem<T>(ByteIdx));
3918+
BByte = static_cast<uint8_t>(B.elem<T>(ByteIdx));
3919+
Dst.elem<T>(ByteIdx) = T::from(GFNIMul(AByte, BByte));
3920+
});
3921+
}
3922+
3923+
Dst.initializeAllElements();
3924+
return true;
3925+
}
3926+
38333927
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
38343928
uint32_t BuiltinID) {
38353929
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -4880,6 +4974,21 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
48804974
return std::pair<unsigned, int>{SrcIdx,
48814975
static_cast<int>(LaneOffset + Index)};
48824976
});
4977+
4978+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
4979+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
4980+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi:
4981+
return interp_builtin_ia32_gfni_affine(S, OpPC, Call, true);
4982+
case X86::BI__builtin_ia32_vgf2p8affineqb_v16qi:
4983+
case X86::BI__builtin_ia32_vgf2p8affineqb_v32qi:
4984+
case X86::BI__builtin_ia32_vgf2p8affineqb_v64qi:
4985+
return interp_builtin_ia32_gfni_affine(S, OpPC, Call, false);
4986+
4987+
case X86::BI__builtin_ia32_vgf2p8mulb_v16qi:
4988+
case X86::BI__builtin_ia32_vgf2p8mulb_v32qi:
4989+
case X86::BI__builtin_ia32_vgf2p8mulb_v64qi:
4990+
return interp__builtin_ia32_gfni_mul(S, OpPC, Call);
4991+
48834992
case X86::BI__builtin_ia32_insertps128:
48844993
return interp__builtin_ia32_shuffle_generic(
48854994
S, OpPC, Call, [](unsigned DstIdx, unsigned Mask) {

clang/lib/AST/ExprConstShared.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
#define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
1616

1717
#include "clang/Basic/TypeTraits.h"
18+
#include <cstdint>
1819

1920
namespace llvm {
2021
class APFloat;
22+
class APInt;
23+
class APSInt;
2124
}
2225
namespace clang {
2326
class QualType;
@@ -74,4 +77,9 @@ void HandleComplexComplexDiv(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
7477
CharUnits GetAlignOfExpr(const ASTContext &Ctx, const Expr *E,
7578
UnaryExprOrTypeTrait ExprKind);
7679

80+
uint8_t GFNIMultiplicativeInverse(uint8_t Byte);
81+
uint8_t GFNIMul(uint8_t AByte, uint8_t BByte);
82+
uint8_t GFNIAffine(uint8_t XByte, const llvm::APInt &AQword,
83+
const llvm::APSInt &Imm, bool Inverse = false);
84+
7785
#endif

clang/lib/AST/ExprConstant.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13869,6 +13869,89 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1386913869
return Success(R, E);
1387013870
}
1387113871

13872+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
13873+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
13874+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi:
13875+
case X86::BI__builtin_ia32_vgf2p8affineqb_v16qi:
13876+
case X86::BI__builtin_ia32_vgf2p8affineqb_v32qi:
13877+
case X86::BI__builtin_ia32_vgf2p8affineqb_v64qi: {
13878+
13879+
APValue X, A;
13880+
APSInt Imm;
13881+
if (!EvaluateAsRValue(Info, E->getArg(0), X) ||
13882+
!EvaluateAsRValue(Info, E->getArg(1), A) ||
13883+
!EvaluateInteger(E->getArg(2), Imm, Info))
13884+
return false;
13885+
13886+
assert(X.isVector() && A.isVector());
13887+
assert(X.getVectorLength() == A.getVectorLength());
13888+
13889+
bool IsInverse = false;
13890+
switch (E->getBuiltinCallee()) {
13891+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v16qi:
13892+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v32qi:
13893+
case X86::BI__builtin_ia32_vgf2p8affineinvqb_v64qi: {
13894+
IsInverse = true;
13895+
}
13896+
}
13897+
13898+
unsigned NumBitsInByte = 8;
13899+
unsigned NumBytesInQWord = 8;
13900+
unsigned NumBitsInQWord = 64;
13901+
unsigned NumBytes = A.getVectorLength();
13902+
unsigned NumQWords = NumBytes / NumBytesInQWord;
13903+
SmallVector<APValue, 64> Result;
13904+
Result.reserve(NumBytes);
13905+
13906+
// computing A*X + Imm
13907+
for (unsigned QWordIdx = 0; QWordIdx != NumQWords; ++QWordIdx) {
13908+
// Extract the QWords from X, A
13909+
APInt XQWord(NumBitsInQWord, 0);
13910+
APInt AQWord(NumBitsInQWord, 0);
13911+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
13912+
unsigned Idx = QWordIdx * NumBytesInQWord + ByteIdx;
13913+
APInt XByte = X.getVectorElt(Idx).getInt();
13914+
APInt AByte = A.getVectorElt(Idx).getInt();
13915+
XQWord.insertBits(XByte, ByteIdx * NumBitsInByte);
13916+
AQWord.insertBits(AByte, ByteIdx * NumBitsInByte);
13917+
}
13918+
13919+
for (unsigned ByteIdx = 0; ByteIdx != NumBytesInQWord; ++ByteIdx) {
13920+
uint8_t XByte =
13921+
XQWord.lshr(ByteIdx * NumBitsInByte).getLoBits(8).getZExtValue();
13922+
Result.push_back(APValue(APSInt(
13923+
APInt(8, GFNIAffine(XByte, AQWord, Imm, IsInverse)), false)));
13924+
}
13925+
}
13926+
13927+
return Success(APValue(Result.data(), Result.size()), E);
13928+
}
13929+
13930+
case X86::BI__builtin_ia32_vgf2p8mulb_v16qi:
13931+
case X86::BI__builtin_ia32_vgf2p8mulb_v32qi:
13932+
case X86::BI__builtin_ia32_vgf2p8mulb_v64qi: {
13933+
APValue A, B;
13934+
if (!EvaluateAsRValue(Info, E->getArg(0), A) ||
13935+
!EvaluateAsRValue(Info, E->getArg(1), B))
13936+
return false;
13937+
13938+
assert(A.isVector() && B.isVector());
13939+
assert(A.getVectorLength() == B.getVectorLength());
13940+
13941+
unsigned NumBytes = A.getVectorLength();
13942+
SmallVector<APValue, 64> Result;
13943+
Result.reserve(NumBytes);
13944+
13945+
for (unsigned ByteIdx = 0; ByteIdx != NumBytes; ++ByteIdx) {
13946+
uint8_t AByte = A.getVectorElt(ByteIdx).getInt().getZExtValue();
13947+
uint8_t BByte = B.getVectorElt(ByteIdx).getInt().getZExtValue();
13948+
Result.push_back(APValue(
13949+
APSInt(APInt(8, GFNIMul(AByte, BByte)), /*IsUnsigned=*/false)));
13950+
}
13951+
13952+
return Success(APValue(Result.data(), Result.size()), E);
13953+
}
13954+
1387213955
case X86::BI__builtin_ia32_insertf32x4_256:
1387313956
case X86::BI__builtin_ia32_inserti32x4_256:
1387413957
case X86::BI__builtin_ia32_insertf64x2_256:
@@ -19478,6 +19561,88 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1947819561
llvm_unreachable("unknown cast resulting in complex value");
1947919562
}
1948019563

19564+
uint8_t GFNIMultiplicativeInverse(uint8_t Byte) {
19565+
// Lookup Table for Multiplicative Inverse in GF(2^8)
19566+
const uint8_t GFInv[256] = {
19567+
0x00, 0x01, 0x8d, 0xf6, 0xcb, 0x52, 0x7b, 0xd1, 0xe8, 0x4f, 0x29, 0xc0,
19568+
0xb0, 0xe1, 0xe5, 0xc7, 0x74, 0xb4, 0xaa, 0x4b, 0x99, 0x2b, 0x60, 0x5f,
19569+
0x58, 0x3f, 0xfd, 0xcc, 0xff, 0x40, 0xee, 0xb2, 0x3a, 0x6e, 0x5a, 0xf1,
19570+
0x55, 0x4d, 0xa8, 0xc9, 0xc1, 0x0a, 0x98, 0x15, 0x30, 0x44, 0xa2, 0xc2,
19571+
0x2c, 0x45, 0x92, 0x6c, 0xf3, 0x39, 0x66, 0x42, 0xf2, 0x35, 0x20, 0x6f,
19572+
0x77, 0xbb, 0x59, 0x19, 0x1d, 0xfe, 0x37, 0x67, 0x2d, 0x31, 0xf5, 0x69,
19573+
0xa7, 0x64, 0xab, 0x13, 0x54, 0x25, 0xe9, 0x09, 0xed, 0x5c, 0x05, 0xca,
19574+
0x4c, 0x24, 0x87, 0xbf, 0x18, 0x3e, 0x22, 0xf0, 0x51, 0xec, 0x61, 0x17,
19575+
0x16, 0x5e, 0xaf, 0xd3, 0x49, 0xa6, 0x36, 0x43, 0xf4, 0x47, 0x91, 0xdf,
19576+
0x33, 0x93, 0x21, 0x3b, 0x79, 0xb7, 0x97, 0x85, 0x10, 0xb5, 0xba, 0x3c,
19577+
0xb6, 0x70, 0xd0, 0x06, 0xa1, 0xfa, 0x81, 0x82, 0x83, 0x7e, 0x7f, 0x80,
19578+
0x96, 0x73, 0xbe, 0x56, 0x9b, 0x9e, 0x95, 0xd9, 0xf7, 0x02, 0xb9, 0xa4,
19579+
0xde, 0x6a, 0x32, 0x6d, 0xd8, 0x8a, 0x84, 0x72, 0x2a, 0x14, 0x9f, 0x88,
19580+
0xf9, 0xdc, 0x89, 0x9a, 0xfb, 0x7c, 0x2e, 0xc3, 0x8f, 0xb8, 0x65, 0x48,
19581+
0x26, 0xc8, 0x12, 0x4a, 0xce, 0xe7, 0xd2, 0x62, 0x0c, 0xe0, 0x1f, 0xef,
19582+
0x11, 0x75, 0x78, 0x71, 0xa5, 0x8e, 0x76, 0x3d, 0xbd, 0xbc, 0x86, 0x57,
19583+
0x0b, 0x28, 0x2f, 0xa3, 0xda, 0xd4, 0xe4, 0x0f, 0xa9, 0x27, 0x53, 0x04,
19584+
0x1b, 0xfc, 0xac, 0xe6, 0x7a, 0x07, 0xae, 0x63, 0xc5, 0xdb, 0xe2, 0xea,
19585+
0x94, 0x8b, 0xc4, 0xd5, 0x9d, 0xf8, 0x90, 0x6b, 0xb1, 0x0d, 0xd6, 0xeb,
19586+
0xc6, 0x0e, 0xcf, 0xad, 0x08, 0x4e, 0xd7, 0xe3, 0x5d, 0x50, 0x1e, 0xb3,
19587+
0x5b, 0x23, 0x38, 0x34, 0x68, 0x46, 0x03, 0x8c, 0xdd, 0x9c, 0x7d, 0xa0,
19588+
0xcd, 0x1a, 0x41, 0x1c};
19589+
19590+
return GFInv[Byte];
19591+
}
19592+
19593+
uint8_t GFNIAffine(uint8_t XByte, const APInt &AQword, const APSInt &Imm,
19594+
bool Inverse) {
19595+
unsigned NumBitsInByte = 8;
19596+
// Computing the affine transformation
19597+
uint8_t RetByte = 0;
19598+
for (uint32_t BitIdx = 0; BitIdx != NumBitsInByte; ++BitIdx) {
19599+
uint8_t AByte =
19600+
AQword.lshr((7 - static_cast<int32_t>(BitIdx)) * NumBitsInByte)
19601+
.getLoBits(8)
19602+
.getZExtValue();
19603+
uint8_t Product;
19604+
if (Inverse) {
19605+
Product = AByte & GFNIMultiplicativeInverse(XByte);
19606+
} else {
19607+
Product = AByte & XByte;
19608+
}
19609+
uint8_t Parity = 0;
19610+
19611+
// Dot product in GF(2) uses XOR instead of addition
19612+
for (unsigned PBitIdx = 0; PBitIdx != NumBitsInByte; ++PBitIdx) {
19613+
Parity = Parity ^ ((Product >> PBitIdx) & 0x1);
19614+
}
19615+
19616+
uint8_t Temp = Imm[BitIdx] ? 1 : 0;
19617+
RetByte |= (Temp ^ Parity) << BitIdx;
19618+
}
19619+
return RetByte;
19620+
}
19621+
19622+
uint8_t GFNIMul(uint8_t AByte, uint8_t BByte) {
19623+
// Multiplying two polynomials of degree 7
19624+
// Polynomial of degree 7
19625+
// x^7 + x^6 + x^5 + x^4 + x^3 + x^2 + x + 1
19626+
uint16_t TWord = 0;
19627+
unsigned NumBitsInByte = 8;
19628+
for (unsigned BitIdx = 0; BitIdx != NumBitsInByte; ++BitIdx) {
19629+
if ((BByte >> BitIdx) & 0x1) {
19630+
TWord = TWord ^ (AByte << BitIdx);
19631+
}
19632+
}
19633+
19634+
// When multiplying two polynomials of degree 7
19635+
// results in a polynomial of degree 14
19636+
// so the result has to be reduced to 7
19637+
// Reduction polynomial is x^8 + x^4 + x^3 + x + 1 i.e. 0x11B
19638+
for (int32_t BitIdx = 14; BitIdx > 7; --BitIdx) {
19639+
if ((TWord >> BitIdx) & 0x1) {
19640+
TWord = TWord ^ (0x11B << (BitIdx - 8));
19641+
}
19642+
}
19643+
return (TWord & 0xFF);
19644+
}
19645+
1948119646
void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
1948219647
APFloat &ResR, APFloat &ResI) {
1948319648
// This is an implementation of complex multiplication according to the

0 commit comments

Comments
 (0)