Skip to content

Commit eee09ca

Browse files
authored
[X86][Clang] Allow constexpr evaluation of F16C CVTPS2PH intrinsics (#162295)
Fixes #160312
1 parent 7b813c3 commit eee09ca

File tree

4 files changed

+227
-2
lines changed

4 files changed

+227
-2
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,11 +710,13 @@ let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in {
710710
def gatherq_d : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int const *, _Vector<2, long long int>, _Vector<4, int>, _Constant char)">;
711711
}
712712

713-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
713+
let Features = "f16c",
714+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
714715
def vcvtps2ph : X86Builtin<"_Vector<8, short>(_Vector<4, float>, _Constant int)">;
715716
}
716717

717-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
718+
let Features = "f16c",
719+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
718720
def vcvtps2ph256 : X86Builtin<"_Vector<8, short>(_Vector<8, float>, _Constant int)">;
719721
}
720722

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3527,7 +3527,94 @@ static bool interp__builtin_ia32_shufbitqmb_mask(InterpState &S, CodePtr OpPC,
35273527
}
35283528

35293529
pushInteger(S, RetMask, Call->getType());
3530+
return true;
3531+
}
3532+
3533+
static bool interp__builtin_ia32_vcvtps2ph(InterpState &S, CodePtr OpPC,
3534+
const CallExpr *Call) {
3535+
// Arguments are: vector of floats, rounding immediate
3536+
assert(Call->getNumArgs() == 2);
3537+
3538+
APSInt Imm = popToAPSInt(S, Call->getArg(1));
3539+
const Pointer &Src = S.Stk.pop<Pointer>();
3540+
const Pointer &Dst = S.Stk.peek<Pointer>();
3541+
3542+
assert(Src.getFieldDesc()->isPrimitiveArray());
3543+
assert(Dst.getFieldDesc()->isPrimitiveArray());
3544+
3545+
const auto *SrcVTy = Call->getArg(0)->getType()->castAs<VectorType>();
3546+
unsigned SrcNumElems = SrcVTy->getNumElements();
3547+
const auto *DstVTy = Call->getType()->castAs<VectorType>();
3548+
unsigned DstNumElems = DstVTy->getNumElements();
3549+
3550+
const llvm::fltSemantics &HalfSem =
3551+
S.getASTContext().getFloatTypeSemantics(S.getASTContext().HalfTy);
3552+
3553+
// imm[2] == 1 means use MXCSR rounding mode.
3554+
// In that case, we can only evaluate if the conversion is exact.
3555+
int ImmVal = Imm.getZExtValue();
3556+
bool UseMXCSR = (ImmVal & 4) != 0;
3557+
bool IsFPConstrained =
3558+
Call->getFPFeaturesInEffect(S.getASTContext().getLangOpts())
3559+
.isFPConstrained();
3560+
3561+
llvm::RoundingMode RM;
3562+
if (!UseMXCSR) {
3563+
switch (ImmVal & 3) {
3564+
case 0:
3565+
RM = llvm::RoundingMode::NearestTiesToEven;
3566+
break;
3567+
case 1:
3568+
RM = llvm::RoundingMode::TowardNegative;
3569+
break;
3570+
case 2:
3571+
RM = llvm::RoundingMode::TowardPositive;
3572+
break;
3573+
case 3:
3574+
RM = llvm::RoundingMode::TowardZero;
3575+
break;
3576+
default:
3577+
llvm_unreachable("Invalid immediate rounding mode");
3578+
}
3579+
} else {
3580+
// For MXCSR, we must check for exactness. We can use any rounding mode
3581+
// for the trial conversion since the result is the same if it's exact.
3582+
RM = llvm::RoundingMode::NearestTiesToEven;
3583+
}
3584+
3585+
QualType DstElemQT = Dst.getFieldDesc()->getElemQualType();
3586+
PrimType DstElemT = *S.getContext().classify(DstElemQT);
3587+
3588+
for (unsigned I = 0; I != SrcNumElems; ++I) {
3589+
Floating SrcVal = Src.elem<Floating>(I);
3590+
APFloat DstVal = SrcVal.getAPFloat();
3591+
3592+
bool LostInfo;
3593+
APFloat::opStatus St = DstVal.convert(HalfSem, RM, &LostInfo);
3594+
3595+
if (UseMXCSR && IsFPConstrained && St != APFloat::opOK) {
3596+
S.FFDiag(S.Current->getSource(OpPC),
3597+
diag::note_constexpr_dynamic_rounding);
3598+
return false;
3599+
}
35303600

3601+
INT_TYPE_SWITCH_NO_BOOL(DstElemT, {
3602+
// Convert the destination value's bit pattern to an unsigned integer,
3603+
// then reconstruct the element using the target type's 'from' method.
3604+
uint64_t RawBits = DstVal.bitcastToAPInt().getZExtValue();
3605+
Dst.elem<T>(I) = T::from(RawBits);
3606+
});
3607+
}
3608+
3609+
// Zero out remaining elements if the destination has more elements
3610+
// (e.g., vcvtps2ph converting 4 floats to 8 shorts).
3611+
if (DstNumElems > SrcNumElems) {
3612+
for (unsigned I = SrcNumElems; I != DstNumElems; ++I) {
3613+
INT_TYPE_SWITCH_NO_BOOL(DstElemT, { Dst.elem<T>(I) = T::from(0); });
3614+
}
3615+
}
3616+
3617+
Dst.initializeAllElements();
35313618
return true;
35323619
}
35333620

@@ -4956,6 +5043,10 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
49565043
case X86::BI__builtin_ia32_insert128i256:
49575044
return interp__builtin_x86_insert_subvector(S, OpPC, Call, BuiltinID);
49585045

5046+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
5047+
case clang::X86::BI__builtin_ia32_vcvtps2ph256:
5048+
return interp__builtin_ia32_vcvtps2ph(S, OpPC, Call);
5049+
49595050
case X86::BI__builtin_ia32_vec_ext_v4hi:
49605051
case X86::BI__builtin_ia32_vec_ext_v16qi:
49615052
case X86::BI__builtin_ia32_vec_ext_v8hi:

clang/lib/AST/ExprConstant.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13909,6 +13909,81 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1390913909
return false;
1391013910
return Success(R, E);
1391113911
}
13912+
13913+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
13914+
case clang::X86::BI__builtin_ia32_vcvtps2ph256: {
13915+
APValue SrcVec;
13916+
if (!EvaluateAsRValue(Info, E->getArg(0), SrcVec))
13917+
return false;
13918+
13919+
APSInt Imm;
13920+
if (!EvaluateInteger(E->getArg(1), Imm, Info))
13921+
return false;
13922+
13923+
const auto *SrcVTy = E->getArg(0)->getType()->castAs<VectorType>();
13924+
unsigned SrcNumElems = SrcVTy->getNumElements();
13925+
const auto *DstVTy = E->getType()->castAs<VectorType>();
13926+
unsigned DstNumElems = DstVTy->getNumElements();
13927+
QualType DstElemTy = DstVTy->getElementType();
13928+
13929+
const llvm::fltSemantics &HalfSem =
13930+
Info.Ctx.getFloatTypeSemantics(Info.Ctx.HalfTy);
13931+
13932+
int ImmVal = Imm.getZExtValue();
13933+
bool UseMXCSR = (ImmVal & 4) != 0;
13934+
bool IsFPConstrained =
13935+
E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()).isFPConstrained();
13936+
13937+
llvm::RoundingMode RM;
13938+
if (!UseMXCSR) {
13939+
switch (ImmVal & 3) {
13940+
case 0:
13941+
RM = llvm::RoundingMode::NearestTiesToEven;
13942+
break;
13943+
case 1:
13944+
RM = llvm::RoundingMode::TowardNegative;
13945+
break;
13946+
case 2:
13947+
RM = llvm::RoundingMode::TowardPositive;
13948+
break;
13949+
case 3:
13950+
RM = llvm::RoundingMode::TowardZero;
13951+
break;
13952+
default:
13953+
llvm_unreachable("Invalid immediate rounding mode");
13954+
}
13955+
} else {
13956+
RM = llvm::RoundingMode::NearestTiesToEven;
13957+
}
13958+
13959+
SmallVector<APValue, 8> ResultElements;
13960+
ResultElements.reserve(DstNumElems);
13961+
13962+
for (unsigned I = 0; I < SrcNumElems; ++I) {
13963+
APFloat SrcVal = SrcVec.getVectorElt(I).getFloat();
13964+
13965+
bool LostInfo;
13966+
APFloat::opStatus St = SrcVal.convert(HalfSem, RM, &LostInfo);
13967+
13968+
if (UseMXCSR && IsFPConstrained && St != APFloat::opOK) {
13969+
Info.FFDiag(E, diag::note_constexpr_dynamic_rounding);
13970+
return false;
13971+
}
13972+
13973+
APSInt DstInt(SrcVal.bitcastToAPInt(),
13974+
DstElemTy->isUnsignedIntegerOrEnumerationType());
13975+
ResultElements.push_back(APValue(DstInt));
13976+
}
13977+
13978+
if (DstNumElems > SrcNumElems) {
13979+
APSInt Zero = Info.Ctx.MakeIntValue(0, DstElemTy);
13980+
for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
13981+
ResultElements.push_back(APValue(Zero));
13982+
}
13983+
}
13984+
13985+
return Success(ResultElements, E);
13986+
}
1391213987
}
1391313988
}
1391413989

clang/test/CodeGen/X86/f16c-builtins.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,31 @@ __m128 test_mm_cvtph_ps(__m128i a) {
4646
return _mm_cvtph_ps(a);
4747
}
4848

49+
// A value exactly halfway between 1.0 and the next representable FP16 number.
50+
// In binary, its significand ends in ...000, followed by a tie-bit 1.
51+
#define POS_HALFWAY (1.0f + 0.00048828125f) // 1.0 + 2^-11, a tie-breaking case
52+
53+
//
54+
// _mm_cvtps_ph (128-bit, 4 floats -> 8 shorts, 4 are zero-padded)
55+
//
56+
// Test values: -2.5f, 1.123f, POS_HALFWAY
57+
TEST_CONSTEXPR(match_v8hi(
58+
_mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
59+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0, 0, 0, 0
60+
));
61+
TEST_CONSTEXPR(match_v8hi(
62+
_mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
63+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
64+
));
65+
TEST_CONSTEXPR(match_v8hi(
66+
_mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
67+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0, 0, 0, 0
68+
));
69+
TEST_CONSTEXPR(match_v8hi(
70+
_mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
71+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
72+
));
73+
4974
__m256 test_mm256_cvtph_ps(__m128i a) {
5075
// CHECK-LABEL: test_mm256_cvtph_ps
5176
// CHECK: fpext <8 x half> %{{.*}} to <8 x float>
@@ -56,12 +81,44 @@ TEST_CONSTEXPR(match_m256(
5681
1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 0.5f, -2.0f, 0.0f
5782
));
5883

84+
//
85+
// _mm256_cvtps_ph (256-bit, 8 floats -> 8 shorts)
86+
//
87+
// Test values: -2.5f, 1.123f, POS_HALFWAY
88+
TEST_CONSTEXPR(match_v8hi(
89+
_mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
90+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0xC100, 0x3C7E, 0x3C00, 0x0000
91+
));
92+
TEST_CONSTEXPR(match_v8hi(
93+
_mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
94+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
95+
));
96+
TEST_CONSTEXPR(match_v8hi(
97+
_mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
98+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0xC100, 0x3C7E, 0x3C01, 0x0000
99+
));
100+
TEST_CONSTEXPR(match_v8hi(
101+
_mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
102+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
103+
));
104+
59105
__m128i test_mm_cvtps_ph(__m128 a) {
60106
// CHECK-LABEL: test_mm_cvtps_ph
61107
// CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.128(<4 x float> %{{.*}}, i32 0)
62108
return _mm_cvtps_ph(a, 0);
63109
}
64110

111+
//
112+
// Tests for Exact Dynamic Rounding
113+
//
114+
// Test that dynamic rounding SUCCEEDS for exactly representable values.
115+
// We use _MM_FROUND_CUR_DIRECTION (value 4) to specify dynamic rounding.
116+
// Inputs: -2.5f, 0.125f, -16.0f are all exactly representable in FP16.
117+
TEST_CONSTEXPR(match_v8hi(
118+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 0.125f, -16.0f, 0.0f, -2.5f, 0.125f, -16.0f, 0.0f), _MM_FROUND_CUR_DIRECTION),
119+
0xC100, 0x3000, 0xCC00, 0x0000, 0xC100, 0x3000, 0xCC00, 0x0000
120+
));
121+
65122
__m128i test_mm256_cvtps_ph(__m256 a) {
66123
// CHECK-LABEL: test_mm256_cvtps_ph
67124
// CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.256(<8 x float> %{{.*}}, i32 0)

0 commit comments

Comments
 (0)