diff --git a/clang/include/clang/Basic/BuiltinsX86.td b/clang/include/clang/Basic/BuiltinsX86.td index 41652259cf6a3..f4688901168f4 100644 --- a/clang/include/clang/Basic/BuiltinsX86.td +++ b/clang/include/clang/Basic/BuiltinsX86.td @@ -697,11 +697,13 @@ let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in { def gatherq_d : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int const *, _Vector<2, long long int>, _Vector<4, int>, _Constant char)">; } -let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in { +let Features = "f16c", + Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in { def vcvtps2ph : X86Builtin<"_Vector<8, short>(_Vector<4, float>, _Constant int)">; } -let Features = "f16c", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in { +let Features = "f16c", + Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in { def vcvtps2ph256 : X86Builtin<"_Vector<8, short>(_Vector<8, float>, _Constant int)">; } diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp index 1eea813b8c556..614a6401c576e 100644 --- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp +++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp @@ -3002,6 +3002,91 @@ static bool interp__builtin_vec_set(InterpState &S, CodePtr OpPC, return true; } +static bool interp__builtin_ia32_vcvtps2ph(InterpState &S, CodePtr OpPC, + const CallExpr *Call) { + // Arguments are: vector of floats, rounding immediate + assert(Call->getNumArgs() == 2); + + APSInt Imm = popToAPSInt(S, Call->getArg(1)); + const Pointer &Src = S.Stk.pop(); + const Pointer &Dst = S.Stk.peek(); + + assert(Src.getFieldDesc()->isPrimitiveArray()); + assert(Dst.getFieldDesc()->isPrimitiveArray()); + + const auto *SrcVTy = Call->getArg(0)->getType()->castAs(); + unsigned SrcNumElems = SrcVTy->getNumElements(); + const auto *DstVTy = Call->getType()->castAs(); + unsigned DstNumElems = DstVTy->getNumElements(); + + const llvm::fltSemantics &HalfSem = + S.getASTContext().getFloatTypeSemantics(S.getASTContext().HalfTy); + + // imm[2] == 1 means use MXCSR rounding mode. + // In that case, we can only evaluate if the conversion is exact. + int ImmVal = Imm.getZExtValue(); + bool UseMXCSR = (ImmVal & 4) != 0; + + llvm::RoundingMode RM; + if (!UseMXCSR) { + switch (ImmVal & 3) { + case 0: + RM = llvm::RoundingMode::NearestTiesToEven; + break; + case 1: + RM = llvm::RoundingMode::TowardNegative; + break; + case 2: + RM = llvm::RoundingMode::TowardPositive; + break; + case 3: + RM = llvm::RoundingMode::TowardZero; + break; + default: + llvm_unreachable("Invalid immediate rounding mode"); + } + } else { + // For MXCSR, we must check for exactness. We can use any rounding mode + // for the trial conversion since the result is the same if it's exact. + RM = llvm::RoundingMode::NearestTiesToEven; + } + + QualType DstElemQT = Dst.getFieldDesc()->getElemQualType(); + PrimType DstElemT = *S.getContext().classify(DstElemQT); + + for (unsigned I = 0; I != SrcNumElems; ++I) { + Floating SrcVal = Src.elem(I); + APFloat DstVal = SrcVal.getAPFloat(); + + bool LostInfo; + APFloat::opStatus St = DstVal.convert(HalfSem, RM, &LostInfo); + + if (UseMXCSR && St != APFloat::opOK) { + S.FFDiag(S.Current->getSource(OpPC), + diag::note_constexpr_dynamic_rounding); + return false; + } + + INT_TYPE_SWITCH_NO_BOOL(DstElemT, { + // Convert the destination value's bit pattern to an unsigned integer, + // then reconstruct the element using the target type's 'from' method. + uint64_t RawBits = DstVal.bitcastToAPInt().getZExtValue(); + Dst.elem(I) = T::from(RawBits); + }); + } + + // Zero out remaining elements if the destination has more elements + // (e.g., vcvtps2ph converting 4 floats to 8 shorts). + if (DstNumElems > SrcNumElems) { + for (unsigned I = SrcNumElems; I != DstNumElems; ++I) { + INT_TYPE_SWITCH_NO_BOOL(DstElemT, { Dst.elem(I) = T::from(0); }); + } + } + + Dst.initializeAllElements(); + return true; +} + bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, uint32_t BuiltinID) { if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID)) @@ -3845,6 +3930,10 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case X86::BI__builtin_ia32_insert128i256: return interp__builtin_x86_insert_subvector(S, OpPC, Call, BuiltinID); + case clang::X86::BI__builtin_ia32_vcvtps2ph: + case clang::X86::BI__builtin_ia32_vcvtps2ph256: + return interp__builtin_ia32_vcvtps2ph(S, OpPC, Call); + case X86::BI__builtin_ia32_vec_ext_v4hi: case X86::BI__builtin_ia32_vec_ext_v16qi: case X86::BI__builtin_ia32_vec_ext_v8hi: diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 618e1636e9e53..7797c95d4f65c 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -12442,6 +12442,79 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) { return Success(APValue(Elems.data(), NumElems), E); } + + case clang::X86::BI__builtin_ia32_vcvtps2ph: + case clang::X86::BI__builtin_ia32_vcvtps2ph256: { + APValue SrcVec; + if (!EvaluateAsRValue(Info, E->getArg(0), SrcVec)) + return false; + + APSInt Imm; + if (!EvaluateInteger(E->getArg(1), Imm, Info)) + return false; + + const auto *SrcVTy = E->getArg(0)->getType()->castAs(); + unsigned SrcNumElems = SrcVTy->getNumElements(); + const auto *DstVTy = E->getType()->castAs(); + unsigned DstNumElems = DstVTy->getNumElements(); + QualType DstElemTy = DstVTy->getElementType(); + + const llvm::fltSemantics &HalfSem = + Info.Ctx.getFloatTypeSemantics(Info.Ctx.HalfTy); + + int ImmVal = Imm.getZExtValue(); + bool UseMXCSR = (ImmVal & 4) != 0; + + llvm::RoundingMode RM; + if (!UseMXCSR) { + switch (ImmVal & 3) { + case 0: + RM = llvm::RoundingMode::NearestTiesToEven; + break; + case 1: + RM = llvm::RoundingMode::TowardNegative; + break; + case 2: + RM = llvm::RoundingMode::TowardPositive; + break; + case 3: + RM = llvm::RoundingMode::TowardZero; + break; + default: + llvm_unreachable("Invalid immediate rounding mode"); + } + } else { + RM = llvm::RoundingMode::NearestTiesToEven; + } + + SmallVector ResultElements; + ResultElements.reserve(DstNumElems); + + for (unsigned I = 0; I < SrcNumElems; ++I) { + APFloat SrcVal = SrcVec.getVectorElt(I).getFloat(); + + bool LostInfo; + APFloat::opStatus St = SrcVal.convert(HalfSem, RM, &LostInfo); + + if (UseMXCSR && St != APFloat::opOK) { + Info.FFDiag(E, diag::note_constexpr_dynamic_rounding); + return false; + } + + APSInt DstInt(SrcVal.bitcastToAPInt(), + DstElemTy->isUnsignedIntegerOrEnumerationType()); + ResultElements.push_back(APValue(DstInt)); + } + + if (DstNumElems > SrcNumElems) { + APSInt Zero = Info.Ctx.MakeIntValue(0, DstElemTy); + for (unsigned I = SrcNumElems; I < DstNumElems; ++I) { + ResultElements.push_back(APValue(Zero)); + } + } + + return Success(ResultElements, E); + } } } diff --git a/clang/test/CodeGen/X86/f16c-builtins.c b/clang/test/CodeGen/X86/f16c-builtins.c index c08ef76d56981..47ff06b270541 100755 --- a/clang/test/CodeGen/X86/f16c-builtins.c +++ b/clang/test/CodeGen/X86/f16c-builtins.c @@ -46,6 +46,37 @@ __m128 test_mm_cvtph_ps(__m128i a) { return _mm_cvtph_ps(a); } +__m128i test_mm_cvtps_ph(__m128 a) { + // CHECK-LABEL: test_mm_cvtps_ph + // CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.128(<4 x float> %{{.*}}, i32 0) + return _mm_cvtps_ph(a, 0); +} + +// A value exactly halfway between 1.0 and the next representable FP16 number. +// In binary, its significand ends in ...000, followed by a tie-bit 1. +#define POS_HALFWAY (1.0f + 0.00048828125f) // 1.0 + 2^-11, a tie-breaking case + +// +// _mm_cvtps_ph (128-bit, 4 floats -> 8 shorts, 4 are zero-padded) +// +// Test values: -2.5f, 1.123f, POS_HALFWAY +TEST_CONSTEXPR(match_v8hi( + _mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT), + 0xC100, 0x3C7E, 0x3C00, 0x0000, 0, 0, 0, 0 +)); +TEST_CONSTEXPR(match_v8hi( + _mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF), + 0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0 +)); +TEST_CONSTEXPR(match_v8hi( + _mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF), + 0xC100, 0x3C7E, 0x3C01, 0x0000, 0, 0, 0, 0 +)); +TEST_CONSTEXPR(match_v8hi( + _mm_cvtps_ph(_mm_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO), + 0xC100, 0x3C7D, 0x3C00, 0x0000, 0, 0, 0, 0 +)); + __m256 test_mm256_cvtph_ps(__m128i a) { // CHECK-LABEL: test_mm256_cvtph_ps // CHECK: fpext <8 x half> %{{.*}} to <8 x float> @@ -56,14 +87,40 @@ TEST_CONSTEXPR(match_m256( 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 0.5f, -2.0f, 0.0f )); -__m128i test_mm_cvtps_ph(__m128 a) { - // CHECK-LABEL: test_mm_cvtps_ph - // CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.128(<4 x float> %{{.*}}, i32 0) - return _mm_cvtps_ph(a, 0); -} +// +// _mm256_cvtps_ph (256-bit, 8 floats -> 8 shorts) +// +// Test values: -2.5f, 1.123f, POS_HALFWAY +TEST_CONSTEXPR(match_v8hi( + _mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEAREST_INT), + 0xC100, 0x3C7E, 0x3C00, 0x0000, 0xC100, 0x3C7E, 0x3C00, 0x0000 +)); +TEST_CONSTEXPR(match_v8hi( + _mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_NEG_INF), + 0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000 +)); +TEST_CONSTEXPR(match_v8hi( + _mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_POS_INF), + 0xC100, 0x3C7E, 0x3C01, 0x0000, 0xC100, 0x3C7E, 0x3C01, 0x0000 +)); +TEST_CONSTEXPR(match_v8hi( + _mm256_cvtps_ph(_mm256_setr_ps(-2.5f, 1.123f, POS_HALFWAY, 0.0f, -2.5f, 1.123f, POS_HALFWAY, 0.0f), _MM_FROUND_TO_ZERO), + 0xC100, 0x3C7D, 0x3C00, 0x0000, 0xC100, 0x3C7D, 0x3C00, 0x0000 +)); + +// +// Tests for Exact Dynamic Rounding +// +// Test that dynamic rounding SUCCEEDS for exactly representable values. +// We use _MM_FROUND_CUR_DIRECTION (value 4) to specify dynamic rounding. +// Inputs: -2.5f, 0.125f, -16.0f are all exactly representable in FP16. +TEST_CONSTEXPR(match_v8hi( + __builtin_ia32_vcvtps2ph256(_mm256_setr_ps(-2.5f, 0.125f, -16.0f, 0.0f, -2.5f, 0.125f, -16.0f, 0.0f), _MM_FROUND_CUR_DIRECTION), + 0xC100, 0x3000, 0xCC00, 0x0000, 0xC100, 0x3000, 0xCC00, 0x0000 +)); __m128i test_mm256_cvtps_ph(__m256 a) { // CHECK-LABEL: test_mm256_cvtps_ph // CHECK: call <8 x i16> @llvm.x86.vcvtps2ph.256(<8 x float> %{{.*}}, i32 0) return _mm256_cvtps_ph(a, 0); -} +} \ No newline at end of file