diff --git a/quaddtype/numpy_quaddtype/src/ops.hpp b/quaddtype/numpy_quaddtype/src/ops.hpp index e103c8d8..fa351d21 100644 --- a/quaddtype/numpy_quaddtype/src/ops.hpp +++ b/quaddtype/numpy_quaddtype/src/ops.hpp @@ -40,7 +40,16 @@ quad_absolute(const Sleef_quad *op) static inline Sleef_quad quad_rint(const Sleef_quad *op) { - return Sleef_rintq1(*op); + Sleef_quad halfway = Sleef_addq1_u05( + Sleef_truncq1(*op), + Sleef_copysignq1(Sleef_cast_from_doubleq1(0.5), *op) + ); + + // Sleef_rintq1 does not handle some near-halfway cases correctly, so we + // manually round up or down when x is not exactly halfway + return Sleef_icmpeqq1(*op, halfway) ? Sleef_rintq1(*op) : ( + Sleef_icmpleq1(*op, halfway) ? Sleef_floorq1(*op) : Sleef_ceilq1(*op) + ); } static inline Sleef_quad diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index 25550db8..0a2cf3ff 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -209,6 +209,12 @@ def test_rounding_functions(op, val): quad_result), f"Zero sign mismatch for {op}({val})" +def test_rint_near_halfway(): + assert np.rint(QuadPrecision("7.4999999999999999")) == 7 + assert np.rint(QuadPrecision("7.49999999999999999")) == 7 + assert np.rint(QuadPrecision("7.5")) == 8 + + @pytest.mark.parametrize("op", ["exp", "exp2"]) @pytest.mark.parametrize("val", [ # Basic cases