Skip to content

Commit 88a8114

Browse files
committed
[X86][Clang] Allow constexpr evaluation of F16C CVTPS2PH intrinsics
1 parent 4099121 commit 88a8114

File tree

4 files changed

+206
-2
lines changed

4 files changed

+206
-2
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,11 +716,13 @@ let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in {
716716
def gatherq_d : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int const *, _Vector<2, long long int>, _Vector<4, int>, _Constant char)">;
717717
}
718718

719-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
719+
let Features = "f16c",
720+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
720721
def vcvtps2ph : X86Builtin<"_Vector<8, short>(_Vector<4, float>, _Constant int)">;
721722
}
722723

723-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
724+
let Features = "f16c",
725+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
724726
def vcvtps2ph256 : X86Builtin<"_Vector<8, short>(_Vector<8, float>, _Constant int)">;
725727
}
726728

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3527,7 +3527,83 @@ static bool interp__builtin_ia32_shufbitqmb_mask(InterpState &S, CodePtr OpPC,
35273527
}
35283528

35293529
pushInteger(S, RetMask, Call->getType());
3530+
return true;
3531+
}
3532+
3533+
/// Constant-evaluates the F16C float-to-half conversion builtins
/// (__builtin_ia32_vcvtps2ph and __builtin_ia32_vcvtps2ph256).
///
/// Pops the rounding immediate and the source float vector off the stack and
/// writes the result into the destination vector that is left on top of the
/// stack. Each source element is converted to the floating-point semantics of
/// HalfTy and stored as the raw bit pattern of the converted value; any
/// destination elements beyond the source element count are set to zero.
///
/// Bits [1:0] of the immediate select the rounding mode. If bit 2 is set, the
/// conversion uses the MXCSR rounding mode, so evaluation only succeeds when
/// every element converts exactly.
///
/// \returns true on success; false, after emitting
/// note_constexpr_dynamic_rounding, if MXCSR rounding was requested and a
/// conversion did not report APFloat::opOK.
static bool interp__builtin_ia32_vcvtps2ph(InterpState &S, CodePtr OpPC,
                                           const CallExpr *Call) {
  // Arguments are: vector of floats, rounding immediate.
  assert(Call->getNumArgs() == 2);

  // Operands are popped in reverse order; the destination stays on the stack.
  APSInt Imm = popToAPSInt(S, Call->getArg(1));
  const Pointer &Src = S.Stk.pop<Pointer>();
  const Pointer &Dst = S.Stk.peek<Pointer>();

  assert(Src.getFieldDesc()->isPrimitiveArray());
  assert(Dst.getFieldDesc()->isPrimitiveArray());

  const auto *SrcVTy = Call->getArg(0)->getType()->castAs<VectorType>();
  unsigned SrcNumElems = SrcVTy->getNumElements();
  const auto *DstVTy = Call->getType()->castAs<VectorType>();
  unsigned DstNumElems = DstVTy->getNumElements();

  const llvm::fltSemantics &HalfSem =
      S.getASTContext().getFloatTypeSemantics(S.getASTContext().HalfTy);

  // imm[2] == 1 means use MXCSR rounding mode.
  // In that case, we can only evaluate if the conversion is exact.
  uint64_t ImmVal = Imm.getZExtValue();
  bool UseMXCSR = (ImmVal & 4) != 0;

  // For MXCSR we only accept exact conversions, so any rounding mode works
  // for the trial conversion: an exact result is the same in every mode.
  llvm::RoundingMode RM = llvm::RoundingMode::NearestTiesToEven;
  if (!UseMXCSR) {
    switch (ImmVal & 3) {
    case 0:
      RM = llvm::RoundingMode::NearestTiesToEven;
      break;
    case 1:
      RM = llvm::RoundingMode::TowardNegative;
      break;
    case 2:
      RM = llvm::RoundingMode::TowardPositive;
      break;
    case 3:
      RM = llvm::RoundingMode::TowardZero;
      break;
    default:
      llvm_unreachable("Invalid immediate rounding mode");
    }
  }

  QualType DstElemQT = Dst.getFieldDesc()->getElemQualType();
  PrimType DstElemT = *S.getContext().classify(DstElemQT);

  for (unsigned I = 0; I < SrcNumElems; ++I) {
    Floating SrcVal = Src.elem<Floating>(I);
    APFloat HalfVal = SrcVal.getAPFloat();

    // LostInfo is a required out-parameter; exactness is judged from the
    // returned status instead.
    bool LostInfo;
    APFloat::opStatus St = HalfVal.convert(HalfSem, RM, &LostInfo);

    if (UseMXCSR && St != APFloat::opOK) {
      S.FFDiag(S.Current->getSource(OpPC),
               diag::note_constexpr_dynamic_rounding);
      return false;
    }

    // Store the bit pattern of the converted value in the integer element.
    uint64_t RawBits = HalfVal.bitcastToAPInt().getZExtValue();
    INT_TYPE_SWITCH_NO_BOOL(DstElemT, { Dst.elem<T>(I) = T::from(RawBits); });
  }

  // Zero out remaining elements if the destination has more elements
  // (e.g., vcvtps2ph converting 4 floats to 8 shorts). The loop is a no-op
  // when the element counts match.
  for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
    INT_TYPE_SWITCH_NO_BOOL(DstElemT, { Dst.elem<T>(I) = T::from(0); });
  }

  Dst.initializeAllElements();
  return true;
}
35333609

@@ -4897,6 +4973,10 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
48974973
case X86::BI__builtin_ia32_vinsertf128_si256:
48984974
case X86::BI__builtin_ia32_insert128i256:
48994975
return interp__builtin_x86_insert_subvector(S, OpPC, Call, BuiltinID);
4976+
4977+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
4978+
case clang::X86::BI__builtin_ia32_vcvtps2ph256:
4979+
return interp__builtin_ia32_vcvtps2ph(S, OpPC, Call);
49004980

49014981
case X86::BI__builtin_ia32_vec_ext_v4hi:
49024982
case X86::BI__builtin_ia32_vec_ext_v16qi:

clang/lib/AST/ExprConstant.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13870,6 +13870,71 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1387013870
return false;
1387113871
return Success(R, E);
1387213872
}
13873+
13874+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
13875+
case clang::X86::BI__builtin_ia32_vcvtps2ph256: {
13876+
APValue SrcVec;
13877+
if (!EvaluateAsRValue(Info, E->getArg(0), SrcVec))
13878+
return false;
13879+
13880+
APSInt Imm;
13881+
if (!EvaluateInteger(E->getArg(1), Imm, Info))
13882+
return false;
13883+
13884+
assert(SrcVec.isVector());
13885+
13886+
const auto *SrcVTy = E->getArg(0)->getType()->castAs<VectorType>();
13887+
unsigned SrcNumElems = SrcVTy->getNumElements();
13888+
const auto *DstVTy = E->getType()->castAs<VectorType>();
13889+
unsigned DstNumElems = DstVTy->getNumElements();
13890+
QualType DstElemTy = DstVTy->getElementType();
13891+
13892+
const llvm::fltSemantics &HalfSem = Info.Ctx.getFloatTypeSemantics(Info.Ctx.HalfTy);
13893+
13894+
int ImmVal = Imm.getZExtValue();
13895+
bool UseMXCSR = (ImmVal & 4) != 0;
13896+
13897+
llvm::RoundingMode RM;
13898+
if (!UseMXCSR) {
13899+
switch (ImmVal & 3) {
13900+
case 0: RM = llvm::RoundingMode::NearestTiesToEven; break;
13901+
case 1: RM = llvm::RoundingMode::TowardNegative; break;
13902+
case 2: RM = llvm::RoundingMode::TowardPositive; break;
13903+
case 3: RM = llvm::RoundingMode::TowardZero; break;
13904+
default: llvm_unreachable("Invalid immediate rounding mode");
13905+
}
13906+
} else {
13907+
RM = llvm::RoundingMode::NearestTiesToEven;
13908+
}
13909+
13910+
SmallVector<APValue, 8> ResultElements;
13911+
ResultElements.reserve(DstNumElems);
13912+
13913+
for (unsigned I = 0; I < SrcNumElems; ++I) {
13914+
APFloat SrcVal = SrcVec.getVectorElt(I).getFloat();
13915+
13916+
bool LostInfo;
13917+
APFloat::opStatus St = SrcVal.convert(HalfSem, RM, &LostInfo);
13918+
13919+
if (UseMXCSR && St != APFloat::opOK) {
13920+
Info.FFDiag(E, diag::note_constexpr_dynamic_rounding);
13921+
return false;
13922+
}
13923+
13924+
APSInt DstInt(SrcVal.bitcastToAPInt(),
13925+
DstElemTy->isUnsignedIntegerOrEnumerationType());
13926+
ResultElements.push_back(APValue(DstInt));
13927+
}
13928+
13929+
if (DstNumElems > SrcNumElems) {
13930+
APSInt Zero = Info.Ctx.MakeIntValue(0, DstElemTy);
13931+
for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
13932+
ResultElements.push_back(APValue(Zero));
13933+
}
13934+
}
13935+
13936+
return Success(ResultElements, E);
13937+
}
1387313938
}
1387413939
}
1387513940

clang/test/CodeGen/X86/f16c-builtins.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,60 @@ __m128i test_mm256_cvtps_ph(__m256 a) {
6767
// CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.256(<8 x float> %{{.*}}, i32 0)
6868
return _mm256_cvtps_ph(a, 0);
6969
}
70+
71+
// A value exactly halfway between 1.0 and the next representable FP16 number.
72+
// FP16 keeps 10 fraction bits, so the spacing at 1.0 is 2^-10 and 1.0 + 2^-11
// is an exact tie: all kept fraction bits are 0 and the first dropped bit is 1.
73+
#define POS_HALFWAY (1.0f + 0.00048828125f) // 1.0 + 2^-11, a tie-breaking case
74+
75+
//
76+
// __builtin_ia32_vcvtps2ph (128-bit, 4 floats -> 8 shorts, 4 are zero-padded)
77+
//
78+
// Test values: -2.5f, 1.123f, POS_HALFWAY, 0.0f
79+
TEST_CONSTEXPR(match_v8hi(
80+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
81+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0, 0, 0, 0
82+
));
83+
TEST_CONSTEXPR(match_v8hi(
84+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
85+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
86+
));
87+
TEST_CONSTEXPR(match_v8hi(
88+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
89+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0, 0, 0, 0
90+
));
91+
TEST_CONSTEXPR(match_v8hi(
92+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
93+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
94+
));
95+
96+
//
97+
// __builtin_ia32_vcvtps2ph256 (256-bit, 8 floats -> 8 shorts)
98+
//
99+
// Test values: -2.5f, 1.123f, POS_HALFWAY, 0.0f (same four values in both halves)
100+
TEST_CONSTEXPR(match_v8hi(
101+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
102+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0xC100, 0x3C7E, 0x3C00, 0x0000
103+
));
104+
TEST_CONSTEXPR(match_v8hi(
105+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
106+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
107+
));
108+
TEST_CONSTEXPR(match_v8hi(
109+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
110+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0xC100, 0x3C7E, 0x3C01, 0x0000
111+
));
112+
TEST_CONSTEXPR(match_v8hi(
113+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
114+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
115+
));
116+
117+
//
118+
// Tests for Exact Dynamic Rounding
119+
//
120+
// Test that dynamic rounding SUCCEEDS for exactly representable values.
121+
// We use _MM_FROUND_CUR_DIRECTION (value 4) to specify dynamic rounding.
122+
// Inputs: -2.5f, 0.125f, -16.0f and 0.0f are all exactly representable in FP16.
123+
TEST_CONSTEXPR(match_v8hi(
124+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 0.125f, -16.0f, 0.0f, -2.5f, 0.125f, -16.0f, 0.0f), _MM_FROUND_CUR_DIRECTION),
125+
0xC100, 0x3000, 0xCC00, 0x0000, 0xC100, 0x3000, 0xCC00, 0x0000
126+
));

0 commit comments

Comments
 (0)