Skip to content

Commit cc5b293

Browse files
committed
[clang] Implement constant evaluation for AVX extract intrinsics (part)
1 parent 46458a4 commit cc5b293

File tree

1 file changed

+109
-2
lines changed

1 file changed

+109
-2
lines changed

clang/lib/AST/ExprConstant.cpp

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12028,7 +12028,114 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1202812028
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1202912029
}
1203012030

12031-
case X86::BI__builtin_ia32_extract128i256:
12031+
case X86::BI__builtin_ia32_extracti32x4_256_mask: // _mm256_extracti32x4_epi32
12032+
case X86::BI__builtin_ia32_extracti32x4_mask: // _mm512_extracti32x4_epi32
12033+
case X86::BI__builtin_ia32_extracti32x8_mask: // _mm512_extracti32x8_epi32
12034+
case X86::BI__builtin_ia32_extracti64x2_256_mask: // _mm256_extracti64x2_epi64
12035+
case X86::BI__builtin_ia32_extracti64x2_512_mask: // _mm512_extracti64x2_epi64
12036+
case X86::BI__builtin_ia32_extracti64x4_mask: { // _mm512_extracti64x4_epi64
12037+
APValue SourceVec, SourceImm, SourceMerge, SourceKmask;
12038+
if (!EvaluateAsRValue(Info, E->getArg(0), SourceVec) ||
12039+
!EvaluateAsRValue(Info, E->getArg(1), SourceImm) ||
12040+
!EvaluateAsRValue(Info, E->getArg(2), SourceMerge) ||
12041+
!EvaluateAsRValue(Info, E->getArg(3), SourceKmask))
12042+
return false;
12043+
12044+
const auto *RetVT = E->getType()->castAs<VectorType>();
12045+
QualType EltTy = RetVT->getElementType();
12046+
unsigned RetLen = RetVT->getNumElements();
12047+
12048+
if (!SourceVec.isVector())
12049+
return false;
12050+
unsigned SrcLen = SourceVec.getVectorLength();
12051+
if (SrcLen % RetLen != 0)
12052+
return false;
12053+
12054+
unsigned NumLanes = SrcLen / RetLen;
12055+
unsigned idx = SourceImm.getInt().getZExtValue() & (NumLanes - 1);
12056+
12057+
// Step 2) Apply kmask (covers plain/mask/maskz):
12058+
// - plain : headers pass kmask=all-ones; merge is undef → always take Extracted.
12059+
// - mask : merge=dst; take? Extracted[i] : dst[i]
12060+
// - maskz : merge=zero; take? Extracted[i] : 0
12061+
uint64_t KmaskBits = SourceKmask.getInt().getZExtValue();
12062+
12063+
auto makeZeroInt = [&]() -> APValue {
12064+
bool Uns = EltTy->isUnsignedIntegerOrEnumerationType();
12065+
unsigned BW = Info.Ctx.getIntWidth(EltTy);
12066+
return APValue(APSInt(APInt(BW, 0), Uns));
12067+
};
12068+
12069+
SmallVector<APValue, 32> ResultElements;
12070+
ResultElements.reserve(RetLen);
12071+
for (unsigned i = 0; i < RetLen; i++) {
12072+
bool Take = (KmaskBits >> i) & 1;
12073+
if (Take) {
12074+
ResultElements.push_back(SourceVec.getVectorElt(idx * RetLen + i));
12075+
} else {
12076+
// For plain (all-ones) this path is never taken.
12077+
// For mask : merge is the original dst element.
12078+
// For maskz : headers pass zero vector as merge.
12079+
const APValue &MergeElt =
12080+
SourceMerge.isVector() ? SourceMerge.getVectorElt(i) : makeZeroInt();
12081+
ResultElements.push_back(MergeElt);
12082+
}
12083+
}
12084+
return Success(APValue(ResultElements.data(), RetLen), E);
12085+
}
12086+
12087+
case X86::BI__builtin_ia32_extractf32x4_256_mask: // _mm256_extractf32x4_ps _mm256_mask_extractf32x4_ps _mm256_maskz_extractf32x4_ps
12088+
case X86::BI__builtin_ia32_extractf32x4_mask: // _mm512_extractf32x4_ps _mm512_mask_extractf32x4_ps _mm512_maskz_extractf32x4_ps
12089+
case X86::BI__builtin_ia32_extractf32x8_mask: // _mm512_extractf32x8_ps _mm512_mask_extractf32x8_ps _mm512_maskz_extractf32x8_ps
12090+
12091+
case X86::BI__builtin_ia32_extractf64x2_256_mask: // _mm256_extractf64x2_pd _mm256_mask_extractf64x2_pd _mm256_maskz_extractf64x2_pd
12092+
case X86::BI__builtin_ia32_extractf64x2_512_mask: // _mm512_extractf64x2_pd _mm512_mask_extractf64x2_pd _mm512_maskz_extractf64x2_pd
12093+
case X86::BI__builtin_ia32_extractf64x4_mask: { // _mm512_extractf64x4_pd _mm512_mask_extractf64x4_pd _mm512_maskz_extractf64x4_pd
12094+
APValue SourceVec, SourceImm, SourceMerge, SourceKmask;
12095+
if (!EvaluateAsRValue(Info, E->getArg(0), SourceVec) ||
12096+
!EvaluateAsRValue(Info, E->getArg(1), SourceImm) ||
12097+
!EvaluateAsRValue(Info, E->getArg(2), SourceMerge) ||
12098+
!EvaluateAsRValue(Info, E->getArg(3), SourceKmask))
12099+
return false;
12100+
12101+
const auto *RetVT = E->getType()->castAs<VectorType>();
12102+
QualType EltTy = RetVT->getElementType();
12103+
unsigned RetLen = RetVT->getNumElements();
12104+
12105+
if (!SourceVec.isVector())
12106+
return false;
12107+
unsigned SrcLen = SourceVec.getVectorLength();
12108+
if (SrcLen % RetLen != 0)
12109+
return false;
12110+
12111+
unsigned NumLanes = SrcLen / RetLen;
12112+
unsigned idx = SourceImm.getInt().getZExtValue() & (NumLanes - 1);
12113+
12114+
uint64_t KmaskBits = SourceKmask.getInt().getZExtValue();
12115+
12116+
auto makeZeroFP = [&]() -> APValue {
12117+
const llvm::fltSemantics &Sem =
12118+
Info.Ctx.getFloatTypeSemantics(EltTy);
12119+
return APValue(llvm::APFloat::getZero(Sem));
12120+
};
12121+
12122+
SmallVector<APValue, 32> ResultElements;
12123+
ResultElements.reserve(RetLen);
12124+
for (unsigned i = 0; i < RetLen; i++) {
12125+
bool Take = (KmaskBits >> i) & 1;
12126+
if (Take) {
12127+
ResultElements.push_back(SourceVec.getVectorElt(idx * RetLen + i));
12128+
} else {
12129+
const APValue &MergeElt =
12130+
SourceMerge.isVector() ? SourceMerge.getVectorElt(i) : makeZeroInt();
12131+
ResultElements.push_back(MergeElt);
12132+
}
12133+
}
12134+
return Success(APValue(ResultElements.data(), RetLen), E);
12135+
}
12136+
12137+
// vector extract
12138+
case X86::BI__builtin_ia32_extract128i256: // avx2
1203212139
case X86::BI__builtin_ia32_vextractf128_pd256:
1203312140
case X86::BI__builtin_ia32_vextractf128_ps256:
1203412141
case X86::BI__builtin_ia32_vextractf128_si256: {
@@ -12044,7 +12151,7 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1204412151
if (SrcLen != RetLen * 2)
1204512152
return false;
1204612153

12047-
SmallVector<APValue, 16> ResultElements;
12154+
SmallVector<APValue, 32> ResultElements;
1204812155
ResultElements.reserve(RetLen);
1204912156

1205012157
for (unsigned i = 0; i < RetLen; i++)

0 commit comments

Comments
 (0)