Skip to content

Commit 4d1dbab

Browse files
Merge pull request #340 from Maegereg/dasm/arb-tests
Add a much more detailed suite of tests for arb
2 parents cddbf41 + b8fde5b commit 4d1dbab

File tree

3 files changed

+365
-1
lines changed

3 files changed

+365
-1
lines changed

src/flint/test/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pyfiles = [
44
'__init__.py',
55
'__main__.py',
66
'test_all.py',
7+
'test_arb.py',
78
'test_docstrings.py',
89
]
910

src/flint/test/test_all.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import flint.flint_base.flint_base as flint_base
1313
from flint.utils.flint_exceptions import DomainError, IncompatibleContextError
1414

15+
from flint.test.test_arb import all_tests as arb_tests
16+
1517

1618
PYPY = platform.python_implementation() == "PyPy"
1719

@@ -5233,4 +5235,4 @@ def test_all_tests():
52335235
test_python_threads,
52345236

52355237
test_all_tests,
5236-
]
5238+
] + arb_tests

src/flint/test/test_arb.py

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
"""Test for python-flint's `arb` type."""
2+
3+
import math
4+
5+
from flint import arb, ctx
6+
7+
def assert_almost_equal(x, y, places=7):
8+
"""Helper method for approximate comparisons."""
9+
assert round(x-y, ndigits=places) == 0
10+
11+
def test_from_int():
12+
"""Tests instantiating `arb`s from ints."""
13+
for val in [
14+
-42 * 10**9,
15+
-42 * 10**7,
16+
-42,
17+
0,
18+
42,
19+
42 * 10**7,
20+
42 * 10**9,
21+
42 * 10**11,
22+
]:
23+
x = arb(val)
24+
man, exp = x.man_exp()
25+
assert (man * 2**exp) == val
26+
27+
def test_from_float():
28+
"""Tests instantiating `arb`s from floats."""
29+
for val in [0.0, 1.1, -1.1, 9.9 * 0.123, 99.12]:
30+
x = arb(val)
31+
man, exp = x.man_exp()
32+
assert (int(man) * 2 ** int(exp)) == val
33+
34+
def test_from_float_inf():
35+
"""Tests `arb` works with +/- inf."""
36+
posinf = arb(float("inf"))
37+
neginf = arb(float("-inf"))
38+
39+
assert not posinf.is_finite()
40+
assert not neginf.is_finite()
41+
assert float(posinf) == float("inf")
42+
assert float(neginf) == float("-inf")
43+
44+
def test_from_man_exp():
45+
"""Tests instantiating `arb`s with mantissa and exponent."""
46+
for man, exp in [(2, 30), (4, 300), (5 * 10**2, 7**8)]:
47+
x = arb(mid=(man, exp))
48+
m, e = x.man_exp()
49+
assert (m * 2**e) == (man * 2**exp)
50+
51+
def test_from_midpoint_radius():
52+
"""Tests instantiating `arb`s with midpoint and radius."""
53+
for mid, rad in [(10, 1), (10000, 5), (10, 1), (10, 1)]:
54+
mid_arb = arb(mid)
55+
rad_arb = arb(rad)
56+
x = arb(mid_arb, rad_arb)
57+
assert x.mid() == mid_arb
58+
actual_radius = float(x.rad())
59+
assert_almost_equal(actual_radius, rad)
60+
61+
def test_is_exact():
62+
"""Tests `arb.is_exact`."""
63+
for arb_val, exact in [
64+
(arb(10), True),
65+
(arb(0.01), True),
66+
(arb(-float("inf")), True),
67+
(arb(1, 0), True),
68+
(arb(1, 1), False),
69+
]:
70+
assert arb_val.is_exact() == exact
71+
72+
def test_is_finite():
73+
"""Tests `arb.is_finite`."""
74+
assert not (arb(-float("inf")).is_finite())
75+
assert not (arb(float("inf")).is_finite())
76+
assert (arb(10).is_finite())
77+
78+
def test_is_nan():
79+
"""Tests `arb.is_nan`."""
80+
assert (arb(float("nan")).is_nan())
81+
assert not (arb(0.0).is_nan())
82+
83+
def test_lower():
84+
"""Tests `arb.lower`."""
85+
with ctx.workprec(100):
86+
arb_val = arb(1, 0.5)
87+
assert_almost_equal(float(arb_val.lower()), 0.5)
88+
89+
def test_upper():
90+
"""Tests `arb.upper`."""
91+
with ctx.workprec(100):
92+
arb_val = arb(1, 0.5)
93+
assert_almost_equal(float(arb_val.upper()), 1.5)
94+
95+
def test_contains():
96+
"""`y.__contains__(x)` returns True iff every number in `x` is also in `y`."""
97+
for x, y, expected in [
98+
(
99+
arb(mid=9, rad=1),
100+
arb(mid=10, rad=2),
101+
True,
102+
),
103+
(
104+
arb(mid=10, rad=2),
105+
arb(mid=9, rad=1),
106+
False,
107+
),
108+
(arb(10), arb(mid=9, rad=1), True),
109+
(arb(10.1), arb(mid=9, rad=1), False),
110+
]:
111+
assert (x in y) == expected
112+
113+
# TODO: Re-enable this if we ever add the ability to hash arbs.
114+
# def test_hash():
115+
# """`x` and `y` hash to the same value if they have the same midpoint and radius.
116+
117+
# Args:
118+
# x: An arb.
119+
# y: An arb.
120+
# expected: Whether `x` and `y` should hash to the same value.
121+
# """
122+
# def arb_pi(prec):
123+
# """Helper to calculate arb to a given precision."""
124+
# with ctx.workprec(prec):
125+
# return arb.pi()
126+
# for x, y, expected in [
127+
# (arb(10), arb(10), True),
128+
# (arb(10), arb(11), False),
129+
# (arb(10.0), arb(10), True),
130+
# (
131+
# arb(mid=10, rad=2),
132+
# arb(mid=10, rad=2),
133+
# True,
134+
# ),
135+
# (
136+
# arb(mid=10, rad=2),
137+
# arb(mid=10, rad=3),
138+
# False,
139+
# ),
140+
# (arb_pi(100), arb_pi(100), True),
141+
# (arb_pi(100), arb_pi(1000), False),
142+
# ]:
143+
# assert (hash(x) == hash(y)) == expected
144+
145+
146+
147+
# Tests for arithmetic functions in `flint.arb`.
148+
149+
# NOTE: Correctness of an arb function `F` is specified as follows:
150+
151+
# If `f` is the corresponding real-valued arithmetic function, `F` is correct
152+
# only if, for any Arb X and any real number x in the interval X,
153+
# `f(x)` is in `F(X)`.
154+
155+
# These tests assume arb.__contains__ is correct.
156+
157+
def test_arb_sub():
158+
"""`arb.__sub__` works as expected."""
159+
arb1 = arb(2, 0.5)
160+
arb2 = arb(1, 1)
161+
with ctx.workprec(100):
162+
actual = arb1 - arb2
163+
# Smallest value in diff => 1.5 - 2 = -0.5
164+
# Largest value in diff => 2.5 - 0 = 2.5
165+
true_interval = arb(1, 1.5) # [-0.5, 2.5]
166+
assert true_interval in actual
167+
168+
def test_arb_add():
169+
"""`arb.__add__` works as expected."""
170+
arb1 = arb(2, 1)
171+
arb2 = arb(1, 1)
172+
with ctx.workprec(100):
173+
actual = arb1 + arb2
174+
true_interval = arb(3, 2) # [1, 5]
175+
assert true_interval in actual
176+
177+
def test_arb_mul():
178+
"""`arb.__mul__` works as expected."""
179+
arb1 = arb(2, 1)
180+
arb2 = arb(1, 1)
181+
with ctx.workprec(100):
182+
actual = arb1 * arb2
183+
true_interval = arb(3, 3) # [0, 6]
184+
assert true_interval in actual
185+
186+
def test_arb_div():
187+
"""`arb.__div__` works as expected."""
188+
arb1 = arb(4, 1)
189+
arb2 = arb(2, 1)
190+
with ctx.workprec(100):
191+
actual = arb1 / arb2
192+
true_interval = arb(4, 1) # [3, 5]
193+
assert true_interval in actual
194+
195+
def test_arb_log():
196+
"""`arb.log` works as expected."""
197+
midpoint = (1 + math.exp(10)) / 2
198+
arb_val = arb(midpoint, midpoint - 1) # [1, exp(10)]
199+
with ctx.workprec(100):
200+
actual = arb_val.log()
201+
true_interval = arb(5, 5) # [0,10]
202+
assert true_interval in actual
203+
204+
def test_arb_exp():
205+
"""`arb.exp` works as expected."""
206+
midpoint = math.log(9) / 2
207+
arb_val = arb(midpoint, midpoint) # [0, log(9)]
208+
with ctx.workprec(100):
209+
actual = arb_val.exp()
210+
true_interval = arb(5, 4) # [1,9]
211+
assert true_interval in actual
212+
213+
def test_arb_max():
214+
"""`arb.max` works as expected."""
215+
arb1 = arb(1.5, 0.5) # [1, 2]
216+
arb2 = arb(1, 2) # [-1, 3]
217+
with ctx.workprec(100):
218+
actual = arb1.max(arb2)
219+
true_interval = arb(2, 1) # [1, 3]
220+
assert true_interval in actual
221+
222+
def test_arb_min():
223+
"""`arb.min` works as expected."""
224+
arb1 = arb(1.5, 0.5) # [1, 2]
225+
arb2 = arb(1, 2) # [-1, 3]
226+
with ctx.workprec(100):
227+
actual = arb1.min(arb2)
228+
true_interval = arb(0.5, 1.5) # [-1, 2]
229+
assert true_interval in actual
230+
231+
def test_arb_abs():
232+
"""`arb.__abs__` works as expected."""
233+
arb_val = arb(1, 2) # [-1,3]
234+
actual = abs(arb_val)
235+
true_interval = arb(1.5, 1.5)
236+
assert true_interval in actual
237+
238+
def test_arb_neg():
239+
"""`arb.neg` works as expected."""
240+
arb_val = arb(1, 2) # [-1,3]
241+
actual = arb_val.neg(exact=True)
242+
true_interval = arb(-2, 1) # [-3,1]
243+
assert true_interval in actual
244+
245+
def test_arb_neg_dunder():
246+
"""`arb.__neg__` works as expected."""
247+
arb_val = arb(1, 2) # [-1,3]
248+
actual = -arb_val
249+
true_interval = arb(-2, 1) # [-3,1]
250+
assert true_interval in actual
251+
252+
def test_arb_sgn():
253+
"""`arb.sgn` works as expected."""
254+
arb1 = arb(1, 0.5) # [0.5,1.5]
255+
arb2 = arb(-1, 0.5) # [-1.5,-0.5]
256+
arb3 = arb(1, 2) # [-1,3]
257+
assert_almost_equal(float(arb1.sgn()), 1)
258+
assert_almost_equal(float(arb2.sgn()), -1)
259+
# arb3 contains both positive and negative numbers
260+
# So, arb_sgn returns [0, 1]
261+
assert_almost_equal(float(arb3.sgn().mid()), 0)
262+
assert_almost_equal(float(arb3.sgn().rad()), 1)
263+
264+
def test_arb_erfinv():
265+
"""`arb.erfinv` works as expected."""
266+
midpoint = (math.erf(1 / 8) + math.erf(1 / 16)) / 2
267+
radius = midpoint - math.erf(1 / 16)
268+
arb_val = arb(midpoint, radius)
269+
with ctx.workprec(100):
270+
actual = arb_val.erfinv()
271+
true_interval = arb(3 / 32, 1 / 32) # [1/16, 1/8]
272+
assert true_interval in actual
273+
274+
def test_arb_erf():
275+
"""`arb.erf` works as expected."""
276+
arb_val = arb(2, 1)
277+
with ctx.workprec(100):
278+
actual = arb_val.erf()
279+
true_interval = arb(
280+
(math.erf(1) + math.erf(3)) / 2,
281+
(math.erf(1) + math.erf(3)) / 2 - math.erf(1)
282+
)
283+
assert true_interval in actual
284+
285+
def test_arb_erfc():
286+
"""`arb.erfc` works as expected."""
287+
arb_val = arb(2, 1)
288+
with ctx.workprec(100):
289+
actual = arb_val.erfc()
290+
true_interval = arb(
291+
(math.erfc(1) + math.erfc(3)) / 2,
292+
(math.erfc(1) + math.erfc(3)) / 2 - math.erfc(3)
293+
)
294+
assert true_interval in actual
295+
296+
def test_arb_const_pi():
297+
"""`arb.pi` works as expected."""
298+
with ctx.workprec(100):
299+
actual = arb.pi()
300+
interval_around_pi = arb(math.pi, 1e-10)
301+
assert actual in interval_around_pi
302+
303+
def test_arb_union():
304+
"""`arb.union` works as expected."""
305+
arb1 = arb(1, 0.5) # [0.5,1.5]
306+
arb2 = arb(3, 0.5) # [2.5,3.5]
307+
with ctx.workprec(100):
308+
actual = arb1.union(arb2)
309+
true_interval = arb(2, 1.5) # [0.5, 3.5]
310+
assert true_interval in actual
311+
312+
def test_arb_sum():
313+
"""`arb.__sum__` works as expected."""
314+
arb1 = arb(1, 0.5) # [0.5,1.5]
315+
arb2 = arb(2, 0.5) # [1.5,2.5]
316+
arb3 = arb(3, 0.5) # [2.5,3.5]
317+
with ctx.workprec(100):
318+
actual = arb1 + arb2 + arb3
319+
true_interval = arb(6, 1.5) # [4.5, 7.5]
320+
assert true_interval in actual
321+
322+
def test_no_tests_missing():
323+
"""Make sure all arb tests are included in all_tests."""
324+
test_funcs = {f for name, f in globals().items() if name.startswith("test_")}
325+
untested = test_funcs - set(all_tests)
326+
assert not untested, f"Untested functions: {untested}"
327+
328+
all_tests = [
329+
test_no_tests_missing,
330+
test_from_int,
331+
test_from_float,
332+
test_from_float_inf,
333+
test_from_man_exp,
334+
test_from_midpoint_radius,
335+
test_is_exact,
336+
test_is_finite,
337+
test_is_nan,
338+
test_lower,
339+
test_upper,
340+
test_contains,
341+
# TODO: Re-enable this if we ever add the ability to hash arbs.
342+
# test_hash,
343+
test_arb_sub,
344+
test_arb_add,
345+
test_arb_mul,
346+
test_arb_div,
347+
test_arb_log,
348+
test_arb_exp,
349+
test_arb_max,
350+
test_arb_min,
351+
test_arb_abs,
352+
test_arb_neg,
353+
test_arb_neg_dunder,
354+
test_arb_sgn,
355+
test_arb_erfinv,
356+
test_arb_erf,
357+
test_arb_erfc,
358+
test_arb_const_pi,
359+
test_arb_union,
360+
test_arb_sum,
361+
]

0 commit comments

Comments
 (0)