Skip to content

Commit 8ab3da3

Browse files
committed
[X86][Clang] Allow constexpr evaluation of F16C CVTPS2PH intrinsics
1 parent 045e09f commit 8ab3da3

File tree

4 files changed

+207
-2
lines changed

4 files changed

+207
-2
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,11 +693,13 @@ let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in {
693693
def gatherq_d : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int const *, _Vector<2, long long int>, _Vector<4, int>, _Constant char)">;
694694
}
695695

696-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
696+
let Features = "f16c",
697+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
697698
def vcvtps2ph : X86Builtin<"_Vector<8, short>(_Vector<4, float>, _Constant int)">;
698699
}
699700

700-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
701+
let Features = "f16c",
702+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
701703
def vcvtps2ph256 : X86Builtin<"_Vector<8, short>(_Vector<8, float>, _Constant int)">;
702704
}
703705

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2878,6 +2878,83 @@ static bool interp__builtin_x86_insert_subvector(InterpState &S, CodePtr OpPC,
28782878
return true;
28792879
}
28802880

2881+
/// Constant-evaluates the F16C VCVTPS2PH builtins (128-bit and 256-bit forms).
///
/// Pops the rounding immediate and the source float vector off the interpreter
/// stack, converts every source element to the float semantics of the AST
/// context's HalfTy, and writes the resulting raw bit pattern into the integer
/// destination vector that is left on top of the stack. Destination lanes past
/// the source element count are zero-filled.
///
/// \param S     Interpreter state providing the stack and AST context.
/// \param OpPC  Code pointer used to locate the diagnostic source.
/// \param Call  The builtin call; expected to have exactly two arguments.
/// \returns false, after emitting note_constexpr_dynamic_rounding, when the
///          immediate selects MXCSR rounding (imm[2] set) and a conversion
///          does not report APFloat::opOK; true otherwise.
static bool interp__builtin_ia32_vcvtps2ph(InterpState &S, CodePtr OpPC,
                                           const CallExpr *Call) {
  // Arguments are: vector of floats, rounding immediate.
  assert(Call->getNumArgs() == 2);

  // Operands are popped in reverse order; the destination stays on the stack.
  APSInt Imm = popToAPSInt(S, Call->getArg(1));
  const Pointer &Src = S.Stk.pop<Pointer>();
  const Pointer &Dst = S.Stk.peek<Pointer>();

  assert(Src.getFieldDesc()->isPrimitiveArray());
  assert(Dst.getFieldDesc()->isPrimitiveArray());

  const unsigned SrcNumElems =
      Call->getArg(0)->getType()->castAs<VectorType>()->getNumElements();
  const unsigned DstNumElems =
      Call->getType()->castAs<VectorType>()->getNumElements();

  const llvm::fltSemantics &HalfSem =
      S.getASTContext().getFloatTypeSemantics(S.getASTContext().HalfTy);

  // imm[2] == 1 means use the MXCSR rounding mode, which is unknown during
  // constant evaluation. In that case we can only evaluate if every
  // conversion is exact.
  const uint64_t ImmVal = Imm.getZExtValue();
  const bool UseMXCSR = (ImmVal & 4) != 0;

  // For MXCSR any rounding mode works for the trial conversion, since an
  // exact result does not depend on it; nearest-even is the default here.
  llvm::RoundingMode RM = llvm::RoundingMode::NearestTiesToEven;
  if (!UseMXCSR) {
    // imm[1:0] selects the static rounding mode.
    switch (ImmVal & 3) {
    case 0:
      RM = llvm::RoundingMode::NearestTiesToEven;
      break;
    case 1:
      RM = llvm::RoundingMode::TowardNegative;
      break;
    case 2:
      RM = llvm::RoundingMode::TowardPositive;
      break;
    case 3:
      RM = llvm::RoundingMode::TowardZero;
      break;
    default:
      llvm_unreachable("Invalid immediate rounding mode");
    }
  }

  PrimType DstElemT =
      *S.getContext().classify(Dst.getFieldDesc()->getElemQualType());

  for (unsigned I = 0; I < SrcNumElems; ++I) {
    Floating SrcVal = Src.elem<Floating>(I);
    APFloat DstVal = SrcVal.getAPFloat();

    // LostInfo is required by the convert() signature; only the returned
    // status is inspected.
    bool LostInfo;
    APFloat::opStatus St = DstVal.convert(HalfSem, RM, &LostInfo);

    if (UseMXCSR && St != APFloat::opOK) {
      S.FFDiag(S.Current->getSource(OpPC),
               diag::note_constexpr_dynamic_rounding);
      return false;
    }

    INT_TYPE_SWITCH_NO_BOOL(DstElemT, {
      // Store the converted value's raw bit pattern in the integer lane.
      uint64_t RawBits = DstVal.bitcastToAPInt().getZExtValue();
      Dst.elem<T>(I) = T::from(RawBits);
    });
  }

  // Zero out remaining elements if the destination has more elements than
  // the source (e.g., vcvtps2ph converting 4 floats to 8 shorts). The loop
  // does nothing when the element counts match.
  for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
    INT_TYPE_SWITCH_NO_BOOL(DstElemT, { Dst.elem<T>(I) = T::from(0); });
  }

  Dst.initializeAllElements();
  return true;
}
2957+
28812958
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
28822959
uint32_t BuiltinID) {
28832960
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -3685,6 +3762,10 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
36853762
case X86::BI__builtin_ia32_vinsertf128_si256:
36863763
case X86::BI__builtin_ia32_insert128i256:
36873764
return interp__builtin_x86_insert_subvector(S, OpPC, Call, BuiltinID);
3765+
3766+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
3767+
case clang::X86::BI__builtin_ia32_vcvtps2ph256:
3768+
return interp__builtin_ia32_vcvtps2ph(S, OpPC, Call);
36883769

36893770
default:
36903771
S.FFDiag(S.Current->getLocation(OpPC),

clang/lib/AST/ExprConstant.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12235,6 +12235,71 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1223512235

1223612236
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1223712237
}
12238+
12239+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
12240+
case clang::X86::BI__builtin_ia32_vcvtps2ph256: {
12241+
APValue SrcVec;
12242+
if (!EvaluateAsRValue(Info, E->getArg(0), SrcVec))
12243+
return false;
12244+
12245+
APSInt Imm;
12246+
if (!EvaluateInteger(E->getArg(1), Imm, Info))
12247+
return false;
12248+
12249+
assert(SrcVec.isVector());
12250+
12251+
const auto *SrcVTy = E->getArg(0)->getType()->castAs<VectorType>();
12252+
unsigned SrcNumElems = SrcVTy->getNumElements();
12253+
const auto *DstVTy = E->getType()->castAs<VectorType>();
12254+
unsigned DstNumElems = DstVTy->getNumElements();
12255+
QualType DstElemTy = DstVTy->getElementType();
12256+
12257+
const llvm::fltSemantics &HalfSem = Info.Ctx.getFloatTypeSemantics(Info.Ctx.HalfTy);
12258+
12259+
int ImmVal = Imm.getZExtValue();
12260+
bool UseMXCSR = (ImmVal & 4) != 0;
12261+
12262+
llvm::RoundingMode RM;
12263+
if (!UseMXCSR) {
12264+
switch (ImmVal & 3) {
12265+
case 0: RM = llvm::RoundingMode::NearestTiesToEven; break;
12266+
case 1: RM = llvm::RoundingMode::TowardNegative; break;
12267+
case 2: RM = llvm::RoundingMode::TowardPositive; break;
12268+
case 3: RM = llvm::RoundingMode::TowardZero; break;
12269+
default: llvm_unreachable("Invalid immediate rounding mode");
12270+
}
12271+
} else {
12272+
RM = llvm::RoundingMode::NearestTiesToEven;
12273+
}
12274+
12275+
SmallVector<APValue, 8> ResultElements;
12276+
ResultElements.reserve(DstNumElems);
12277+
12278+
for (unsigned I = 0; I < SrcNumElems; ++I) {
12279+
APFloat SrcVal = SrcVec.getVectorElt(I).getFloat();
12280+
12281+
bool LostInfo;
12282+
APFloat::opStatus St = SrcVal.convert(HalfSem, RM, &LostInfo);
12283+
12284+
if (UseMXCSR && St != APFloat::opOK) {
12285+
Info.FFDiag(E, diag::note_constexpr_dynamic_rounding);
12286+
return false;
12287+
}
12288+
12289+
APSInt DstInt(SrcVal.bitcastToAPInt(),
12290+
DstElemTy->isUnsignedIntegerOrEnumerationType());
12291+
ResultElements.push_back(APValue(DstInt));
12292+
}
12293+
12294+
if (DstNumElems > SrcNumElems) {
12295+
APSInt Zero = Info.Ctx.MakeIntValue(0, DstElemTy);
12296+
for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
12297+
ResultElements.push_back(APValue(Zero));
12298+
}
12299+
}
12300+
12301+
return Success(ResultElements, E);
12302+
}
1223812303
}
1223912304
}
1224012305

clang/test/CodeGen/X86/f16c-builtins.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,60 @@ __m128i test_mm256_cvtps_ph(__m256 a) {
6767
// CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.256(<8 x float> %{{.*}}, i32 0)
6868
return _mm256_cvtps_ph(a, 0);
6969
}
70+
71+
// A value exactly halfway between 1.0 and the next representable FP16 number.
72+
// In binary, its significand ends in ...000, followed by a tie-bit 1.
73+
#define POS_HALFWAY (1.0f + 0.00048828125f) // 1.0 + 2^-11, a tie-breaking case
74+
75+
//
76+
// __builtin_ia32_vcvtps2ph (128-bit, 4 floats -> 8 shorts, 4 are zero-padded)
77+
//
78+
// Test values: -2.5f, 1.123f, POS_HALFWAY, 0.0f
79+
TEST_CONSTEXPR(match_v8hi(
80+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
81+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0, 0, 0, 0
82+
));
83+
TEST_CONSTEXPR(match_v8hi(
84+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
85+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
86+
));
87+
TEST_CONSTEXPR(match_v8hi(
88+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
89+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0, 0, 0, 0
90+
));
91+
TEST_CONSTEXPR(match_v8hi(
92+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
93+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
94+
));
95+
96+
//
97+
// __builtin_ia32_vcvtps2ph256 (256-bit, 8 floats -> 8 shorts)
98+
//
99+
// Test values (repeated in both 128-bit halves): -2.5f, 1.123f, POS_HALFWAY, 0.0f
100+
TEST_CONSTEXPR(match_v8hi(
101+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
102+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0xC100, 0x3C7E, 0x3C00, 0x0000
103+
));
104+
TEST_CONSTEXPR(match_v8hi(
105+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
106+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
107+
));
108+
TEST_CONSTEXPR(match_v8hi(
109+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
110+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0xC100, 0x3C7E, 0x3C01, 0x0000
111+
));
112+
TEST_CONSTEXPR(match_v8hi(
113+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
114+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
115+
));
116+
117+
//
118+
// Tests for Exact Dynamic Rounding
119+
//
120+
// Test that dynamic rounding SUCCEEDS for exactly representable values.
121+
// We use _MM_FROUND_CUR_DIRECTION (value 4) to specify dynamic rounding.
122+
// Inputs: -2.5f, 0.125f, -16.0f, and 0.0f are all exactly representable in FP16.
123+
TEST_CONSTEXPR(match_v8hi(
124+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 0.125f, -16.0f, 0.0f, -2.5f, 0.125f, -16.0f, 0.0f), _MM_FROUND_CUR_DIRECTION),
125+
0xC100, 0x3000, 0xCC00, 0x0000, 0xC100, 0x3000, 0xCC00, 0x0000
126+
));

0 commit comments

Comments
 (0)