Skip to content

Commit cd7f089

Browse files
committed
Add more tests for -0.0
1 parent 2926089 commit cd7f089

File tree

1 file changed

+38
-8
lines changed

1 file changed

+38
-8
lines changed

quaddtype/tests/test_quaddtype.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,28 @@ def test_basic_equality():
3535

3636

3737
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv", "pow", "copysign"])
38-
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
39-
def test_binary_ops(op, other):
40-
if op == "truediv" and float(other) == 0:
38+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
39+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
40+
def test_binary_ops(op, a, b):
41+
if op == "truediv" and float(b) == 0:
4142
pytest.xfail("float division by zero")
4243

4344
op_func = getattr(operator, op, None) or getattr(np, op)
44-
quad_a = QuadPrecision("12.5")
45-
quad_b = QuadPrecision(other)
46-
float_a = 12.5
47-
float_b = float(other)
45+
quad_a = QuadPrecision(a)
46+
quad_b = QuadPrecision(b)
47+
float_a = float(a)
48+
float_b = float(b)
4849

4950
quad_result = op_func(quad_a, quad_b)
5051
float_result = op_func(float_a, float_b)
5152

5253
np.testing.assert_allclose(np.float64(quad_result), float_result, atol=1e-10, rtol=0, equal_nan=True)
5354

55+
# Check sign for zero results
56+
if float_result == 0.0:
57+
assert np.signbit(float_result) == np.signbit(
58+
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
59+
5460

5561
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
5662
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
@@ -93,6 +99,11 @@ def test_array_minmax(op, a, b):
9399

94100
np.testing.assert_array_equal(quad_res.astype(float), float_res)
95101

102+
# Check sign for zero results
103+
if float_result == 0.0:
104+
assert np.signbit(float_result) == np.signbit(
105+
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
106+
96107

97108
@pytest.mark.parametrize("op", ["amin", "amax", "nanmin", "nanmax"])
98109
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
@@ -107,6 +118,11 @@ def test_array_aminmax(op, a, b):
107118

108119
np.testing.assert_array_equal(np.array(quad_res).astype(float), float_res)
109120

121+
# Check sign for zero results
122+
if float_result == 0.0:
123+
assert np.signbit(float_result) == np.signbit(
124+
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
125+
110126

111127
@pytest.mark.parametrize("op", ["negative", "positive", "absolute", "sign", "signbit", "isfinite", "isinf", "isnan", "sqrt", "square", "reciprocal"])
112128
@pytest.mark.parametrize("val", ["3.0", "-3.0", "12.5", "100.0", "1e100", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
@@ -126,7 +142,7 @@ def test_unary_ops(op, val):
126142

127143
np.testing.assert_array_equal(np.array(quad_result).astype(float), float_result)
128144

129-
if op in ["negative", "positive", "absolute", "sign"]:
145+
if (float_result == 0.0) and (op not in ["signbit", "isfinite", "isinf", "isnan"]):
130146
assert np.signbit(float_result) == np.signbit(quad_result)
131147

132148

@@ -290,6 +306,11 @@ def test_logarithmic_functions(op, val):
290306
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol,
291307
err_msg=f"Value mismatch for {op}({val})")
292308

309+
# Check sign for zero results
310+
if float_result == 0.0:
311+
assert np.signbit(float_result) == np.signbit(
312+
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
313+
293314

294315
@pytest.mark.parametrize("val", [
295316
# Basic cases around -1 (critical point for log1p)
@@ -304,6 +325,8 @@ def test_logarithmic_functions(op, val):
304325
"-1.1", "-2.0", "-10.0",
305326
# Large positive values
306327
"1e10", "1e15", "1e100",
328+
# Edge cases
329+
"0.0", "-0.0",
307330
# Special values
308331
"inf", "-inf", "nan", "-nan"
309332
])
@@ -341,9 +364,16 @@ def test_log1p(val):
341364
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol,
342365
err_msg=f"Value mismatch for log1p({val})")
343366

367+
# Check sign for zero results
368+
if float_result == 0.0:
369+
assert np.signbit(float_result) == np.signbit(
370+
quad_result), f"Zero sign mismatch for {op}({val})"
371+
344372
def test_inf():
345373
assert QuadPrecision("inf") > QuadPrecision("1e1000")
374+
assert np.signbit(QuadPrecision("inf")) == 0
346375
assert QuadPrecision("-inf") < QuadPrecision("-1e1000")
376+
assert np.signbit(QuadPrecision("-inf")) == 1
347377

348378

349379
def test_dtype_creation():

0 commit comments

Comments
 (0)