Skip to content

Commit cd19212

Browse files
committed
linspace+fill+type promotion tests:
1 parent a89be9c commit cd19212

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

quaddtype/tests/test_quaddtype.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,74 @@ def test_hyperbolic_functions(op, val):
591591
if float_result == 0.0:
592592
assert np.signbit(float_result) == np.signbit(
593593
quad_result), f"Zero sign mismatch for {op}({val})"
594+
595+
596+
class TestTypePomotionWithPythonAbstractTypes:
597+
"""Tests for common_dtype handling of Python abstract dtypes (PyLongDType, PyFloatDType)"""
598+
599+
def test_promotion_with_python_int(self):
600+
"""Test that Python int promotes to QuadPrecDType"""
601+
# Create array from Python int
602+
arr = np.array([1, 2, 3], dtype=QuadPrecDType)
603+
assert arr.dtype.name == "QuadPrecDType128"
604+
assert len(arr) == 3
605+
assert float(arr[0]) == 1.0
606+
assert float(arr[1]) == 2.0
607+
assert float(arr[2]) == 3.0
608+
609+
def test_promotion_with_python_float(self):
610+
"""Test that Python float promotes to QuadPrecDType"""
611+
# Create array from Python float
612+
arr = np.array([1.5, 2.7, 3.14], dtype=QuadPrecDType)
613+
assert arr.dtype.name == "QuadPrecDType128"
614+
assert len(arr) == 3
615+
np.testing.assert_allclose(float(arr[0]), 1.5, rtol=1e-15)
616+
np.testing.assert_allclose(float(arr[1]), 2.7, rtol=1e-15)
617+
np.testing.assert_allclose(float(arr[2]), 3.14, rtol=1e-15)
618+
619+
def test_result_dtype_binary_ops_with_python_types(self):
620+
"""Test that binary operations between QuadPrecDType and Python scalars return QuadPrecDType"""
621+
quad_arr = np.array([QuadPrecision("1.0"), QuadPrecision("2.0")])
622+
623+
# Addition with Python int
624+
result = quad_arr + 5
625+
assert result.dtype.name == "QuadPrecDType128"
626+
assert float(result[0]) == 6.0
627+
assert float(result[1]) == 7.0
628+
629+
# Multiplication with Python float
630+
result = quad_arr * 2.5
631+
assert result.dtype.name == "QuadPrecDType128"
632+
np.testing.assert_allclose(float(result[0]), 2.5, rtol=1e-15)
633+
np.testing.assert_allclose(float(result[1]), 5.0, rtol=1e-15)
634+
635+
def test_concatenate_with_python_types(self):
636+
"""Test concatenation handles Python numeric types correctly"""
637+
quad_arr = np.array([QuadPrecision("1.0")])
638+
# This should work if promotion is correct
639+
int_arr = np.array([2], dtype=np.int64)
640+
641+
# The result dtype should be QuadPrecDType
642+
result = np.concatenate([quad_arr, int_arr.astype(QuadPrecDType)])
643+
assert result.dtype.name == "QuadPrecDType128"
644+
assert len(result) == 2
645+
646+
647+
@pytest.mark.parametrize("func,args,expected", [
648+
# arange tests
649+
(np.arange, (0, 10), list(range(10))),
650+
(np.arange, (0, 10, 2), [0, 2, 4, 6, 8]),
651+
(np.arange, (0.0, 5.0, 0.5), [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]),
652+
(np.arange, (10, 0, -1), [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),
653+
(np.arange, (-5, 5), list(range(-5, 5))),
654+
# linspace tests
655+
(np.linspace, (0, 10, 11), list(range(11))),
656+
(np.linspace, (0, 1, 5), [0.0, 0.25, 0.5, 0.75, 1.0]),
657+
])
658+
def test_fill_function(func, args, expected):
659+
"""Test quadprec_fill function with arange and linspace"""
660+
arr = func(*args, dtype=QuadPrecDType())
661+
assert arr.dtype.name == "QuadPrecDType128"
662+
assert len(arr) == len(expected)
663+
for i, exp_val in enumerate(expected):
664+
np.testing.assert_allclose(float(arr[i]), float(exp_val), rtol=1e-15, atol=1e-15)

0 commit comments

Comments
 (0)