Skip to content

Commit 7757531

Browse files
committed
[X86][Clang] Allow constexpr evaluation of F16C CVTPS2PH intrinsics
1 parent b36e762 commit 7757531

File tree

4 files changed

+207
-2
lines changed

4 files changed

+207
-2
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -697,11 +697,13 @@ let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in {
697697
def gatherq_d : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int const *, _Vector<2, long long int>, _Vector<4, int>, _Constant char)">;
698698
}
699699

700-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
700+
let Features = "f16c",
701+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
701702
def vcvtps2ph : X86Builtin<"_Vector<8, short>(_Vector<4, float>, _Constant int)">;
702703
}
703704

704-
let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
705+
let Features = "f16c",
706+
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
705707
def vcvtps2ph256 : X86Builtin<"_Vector<8, short>(_Vector<8, float>, _Constant int)">;
706708
}
707709

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3002,6 +3002,83 @@ static bool interp__builtin_vec_set(InterpState &S, CodePtr OpPC,
30023002
return true;
30033003
}
30043004

3005+
static bool interp__builtin_ia32_vcvtps2ph(InterpState &S, CodePtr OpPC,
3006+
const CallExpr *Call) {
3007+
// Arguments are: vector of floats, rounding immediate
3008+
assert(Call->getNumArgs() == 2);
3009+
3010+
APSInt Imm = popToAPSInt(S, Call->getArg(1));
3011+
const Pointer &Src = S.Stk.pop<Pointer>();
3012+
const Pointer &Dst = S.Stk.peek<Pointer>();
3013+
3014+
assert(Src.getFieldDesc()->isPrimitiveArray());
3015+
assert(Dst.getFieldDesc()->isPrimitiveArray());
3016+
3017+
const auto *SrcVTy = Call->getArg(0)->getType()->castAs<VectorType>();
3018+
unsigned SrcNumElems = SrcVTy->getNumElements();
3019+
const auto *DstVTy = Call->getType()->castAs<VectorType>();
3020+
unsigned DstNumElems = DstVTy->getNumElements();
3021+
3022+
const llvm::fltSemantics &HalfSem =
3023+
S.getASTContext().getFloatTypeSemantics(S.getASTContext().HalfTy);
3024+
3025+
// imm[2] == 1 means use MXCSR rounding mode.
3026+
// In that case, we can only evaluate if the conversion is exact.
3027+
int ImmVal = Imm.getZExtValue();
3028+
bool UseMXCSR = (ImmVal & 4) != 0;
3029+
3030+
llvm::RoundingMode RM;
3031+
if (!UseMXCSR) {
3032+
switch (ImmVal & 3) {
3033+
case 0: RM = llvm::RoundingMode::NearestTiesToEven; break;
3034+
case 1: RM = llvm::RoundingMode::TowardNegative; break;
3035+
case 2: RM = llvm::RoundingMode::TowardPositive; break;
3036+
case 3: RM = llvm::RoundingMode::TowardZero; break;
3037+
default: llvm_unreachable("Invalid immediate rounding mode");
3038+
}
3039+
} else {
3040+
// For MXCSR, we must check for exactness. We can use any rounding mode
3041+
// for the trial conversion since the result is the same if it's exact.
3042+
RM = llvm::RoundingMode::NearestTiesToEven;
3043+
}
3044+
3045+
QualType DstElemQT = Dst.getFieldDesc()->getElemQualType();
3046+
PrimType DstElemT = *S.getContext().classify(DstElemQT);
3047+
bool DstIsUnsigned = DstElemQT->isUnsignedIntegerOrEnumerationType();
3048+
3049+
for (unsigned I = 0; I < SrcNumElems; ++I) {
3050+
Floating SrcVal = Src.elem<Floating>(I);
3051+
APFloat DstVal = SrcVal.getAPFloat();
3052+
3053+
bool LostInfo;
3054+
APFloat::opStatus St = DstVal.convert(HalfSem, RM, &LostInfo);
3055+
3056+
if (UseMXCSR && St != APFloat::opOK) {
3057+
S.FFDiag(S.Current->getSource(OpPC), diag::note_constexpr_dynamic_rounding);
3058+
return false;
3059+
}
3060+
3061+
INT_TYPE_SWITCH_NO_BOOL(DstElemT, {
3062+
// FIX: Extract the integer value before calling 'from'.
3063+
uint64_t RawBits = DstVal.bitcastToAPInt().getZExtValue();
3064+
Dst.elem<T>(I) = T::from(RawBits);
3065+
});
3066+
}
3067+
3068+
// Zero out remaining elements if the destination has more elements
3069+
// (e.g., vcvtps2ph converting 4 floats to 8 shorts).
3070+
if (DstNumElems > SrcNumElems) {
3071+
for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
3072+
INT_TYPE_SWITCH_NO_BOOL(DstElemT, {
3073+
Dst.elem<T>(I) = T::from(0);
3074+
});
3075+
}
3076+
}
3077+
3078+
Dst.initializeAllElements();
3079+
return true;
3080+
}
3081+
30053082
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
30063083
uint32_t BuiltinID) {
30073084
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -3844,6 +3921,10 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
38443921
case X86::BI__builtin_ia32_vinsertf128_si256:
38453922
case X86::BI__builtin_ia32_insert128i256:
38463923
return interp__builtin_x86_insert_subvector(S, OpPC, Call, BuiltinID);
3924+
3925+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
3926+
case clang::X86::BI__builtin_ia32_vcvtps2ph256:
3927+
return interp__builtin_ia32_vcvtps2ph(S, OpPC, Call);
38473928

38483929
case X86::BI__builtin_ia32_vec_ext_v4hi:
38493930
case X86::BI__builtin_ia32_vec_ext_v16qi:

clang/lib/AST/ExprConstant.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12442,6 +12442,71 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1244212442

1244312443
return Success(APValue(Elems.data(), NumElems), E);
1244412444
}
12445+
12446+
case clang::X86::BI__builtin_ia32_vcvtps2ph:
12447+
case clang::X86::BI__builtin_ia32_vcvtps2ph256: {
12448+
APValue SrcVec;
12449+
if (!EvaluateAsRValue(Info, E->getArg(0), SrcVec))
12450+
return false;
12451+
12452+
APSInt Imm;
12453+
if (!EvaluateInteger(E->getArg(1), Imm, Info))
12454+
return false;
12455+
12456+
assert(SrcVec.isVector());
12457+
12458+
const auto *SrcVTy = E->getArg(0)->getType()->castAs<VectorType>();
12459+
unsigned SrcNumElems = SrcVTy->getNumElements();
12460+
const auto *DstVTy = E->getType()->castAs<VectorType>();
12461+
unsigned DstNumElems = DstVTy->getNumElements();
12462+
QualType DstElemTy = DstVTy->getElementType();
12463+
12464+
const llvm::fltSemantics &HalfSem = Info.Ctx.getFloatTypeSemantics(Info.Ctx.HalfTy);
12465+
12466+
int ImmVal = Imm.getZExtValue();
12467+
bool UseMXCSR = (ImmVal & 4) != 0;
12468+
12469+
llvm::RoundingMode RM;
12470+
if (!UseMXCSR) {
12471+
switch (ImmVal & 3) {
12472+
case 0: RM = llvm::RoundingMode::NearestTiesToEven; break;
12473+
case 1: RM = llvm::RoundingMode::TowardNegative; break;
12474+
case 2: RM = llvm::RoundingMode::TowardPositive; break;
12475+
case 3: RM = llvm::RoundingMode::TowardZero; break;
12476+
default: llvm_unreachable("Invalid immediate rounding mode");
12477+
}
12478+
} else {
12479+
RM = llvm::RoundingMode::NearestTiesToEven;
12480+
}
12481+
12482+
SmallVector<APValue, 8> ResultElements;
12483+
ResultElements.reserve(DstNumElems);
12484+
12485+
for (unsigned I = 0; I < SrcNumElems; ++I) {
12486+
APFloat SrcVal = SrcVec.getVectorElt(I).getFloat();
12487+
12488+
bool LostInfo;
12489+
APFloat::opStatus St = SrcVal.convert(HalfSem, RM, &LostInfo);
12490+
12491+
if (UseMXCSR && St != APFloat::opOK) {
12492+
Info.FFDiag(E, diag::note_constexpr_dynamic_rounding);
12493+
return false;
12494+
}
12495+
12496+
APSInt DstInt(SrcVal.bitcastToAPInt(),
12497+
DstElemTy->isUnsignedIntegerOrEnumerationType());
12498+
ResultElements.push_back(APValue(DstInt));
12499+
}
12500+
12501+
if (DstNumElems > SrcNumElems) {
12502+
APSInt Zero = Info.Ctx.MakeIntValue(0, DstElemTy);
12503+
for (unsigned I = SrcNumElems; I < DstNumElems; ++I) {
12504+
ResultElements.push_back(APValue(Zero));
12505+
}
12506+
}
12507+
12508+
return Success(ResultElements, E);
12509+
}
1244512510
}
1244612511
}
1244712512

clang/test/CodeGen/X86/f16c-builtins.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,60 @@ __m128i test_mm256_cvtps_ph(__m256 a) {
6767
// CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.256(<8 x float> %{{.*}}, i32 0)
6868
return _mm256_cvtps_ph(a, 0);
6969
}
70+
71+
// A value exactly halfway between 1.0 and the next representable FP16 number.
72+
// In binary, its significand ends in ...000, followed by a tie-bit 1.
73+
#define POS_HALFWAY (1.0f + 0.00048828125f) // 1.0 + 2^-11, a tie-breaking case
74+
75+
//
76+
// __builtin_ia32_vcvtps2ph (128-bit, 4 floats -> 8 shorts, 4 are zero-padded)
77+
//
78+
// Test values: -2.5f, 1.123f, POS_HALFWAY
79+
TEST_CONSTEXPR(match_v8hi(
80+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
81+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0, 0, 0, 0
82+
));
83+
TEST_CONSTEXPR(match_v8hi(
84+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
85+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
86+
));
87+
TEST_CONSTEXPR(match_v8hi(
88+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
89+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0, 0, 0, 0
90+
));
91+
TEST_CONSTEXPR(match_v8hi(
92+
__builtin_ia32_vcvtps2ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
93+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0
94+
));
95+
96+
//
97+
// __builtin_ia32_vcvtps2ph256 (256-bit, 8 floats -> 8 shorts)
98+
//
99+
// Test values: -2.5f, 1.123f, POS_HALFWAY
100+
TEST_CONSTEXPR(match_v8hi(
101+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT),
102+
0xC100, 0x3C7E, 0x3C00, 0x0000, 0xC100, 0x3C7E, 0x3C00, 0x0000
103+
));
104+
TEST_CONSTEXPR(match_v8hi(
105+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF),
106+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
107+
));
108+
TEST_CONSTEXPR(match_v8hi(
109+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF),
110+
0xC100, 0x3C7E, 0x3C01, 0x0000, 0xC100, 0x3C7E, 0x3C01, 0x0000
111+
));
112+
TEST_CONSTEXPR(match_v8hi(
113+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO),
114+
0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000
115+
));
116+
117+
//
118+
// Tests for Exact Dynamic Rounding
119+
//
120+
// Test that dynamic rounding SUCCEEDS for exactly representable values.
121+
// We use _MM_FROUND_CUR_DIRECTION (value 4) to specify dynamic rounding.
122+
// Inputs: -2.5f, 0.125f, -16.0f are all exactly representable in FP16.
123+
TEST_CONSTEXPR(match_v8hi(
124+
__builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 0.125f, -16.0f, 0.0f, -2.5f, 0.125f, -16.0f, 0.0f), _MM_FROUND_CUR_DIRECTION),
125+
0xC100, 0x3000, 0xCC00, 0x0000, 0xC100, 0x3000, 0xCC00, 0x0000
126+
));

0 commit comments

Comments
 (0)