Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion quaddtype/numpy_quaddtype/src/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,16 @@ quad_absolute(const Sleef_quad *op)
static inline Sleef_quad
quad_rint(const Sleef_quad *op)
{
return Sleef_rintq1(*op);
Sleef_quad halfway = Sleef_addq1_u05(
Sleef_truncq1(*op),
Sleef_copysignq1(Sleef_cast_from_doubleq1(0.5), *op)
);

// Sleef_rintq1 does not handle some near-halfway cases correctly, so we
// manually round up or down when x is not exactly halfway
return Sleef_icmpeqq1(*op, halfway) ? Sleef_rintq1(*op) : (
Sleef_icmpleq1(*op, halfway) ? Sleef_floorq1(*op) : Sleef_ceilq1(*op)
);
}

static inline Sleef_quad
Expand Down
6 changes: 6 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ def test_rounding_functions(op, val):
quad_result), f"Zero sign mismatch for {op}({val})"


def test_rint_near_halfway():
assert np.rint(QuadPrecision("7.4999999999999999")) == 7
assert np.rint(QuadPrecision("7.49999999999999999")) == 7
assert np.rint(QuadPrecision("7.5")) == 8


@pytest.mark.parametrize("op", ["exp", "exp2"])
@pytest.mark.parametrize("val", [
# Basic cases
Expand Down
Loading