Skip to content

Commit 2aaee72

Browse files
committed
Add a much more detailed suite of tests for arb.
1 parent cddbf41 commit 2aaee72

File tree

4 files changed

+355
-1
lines changed

4 files changed

+355
-1
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ pytest-cov
99
sphinx
1010
sphinx-rtd-theme
1111
furo
12+
scipy

src/flint/test/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pyfiles = [
44
'__init__.py',
55
'__main__.py',
66
'test_all.py',
7+
'test_arb.py',
78
'test_docstrings.py',
89
]
910

src/flint/test/test_all.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import flint.flint_base.flint_base as flint_base
1313
from flint.utils.flint_exceptions import DomainError, IncompatibleContextError
1414

15+
from flint.test.test_arb import all_tests as arb_tests
16+
1517

1618
PYPY = platform.python_implementation() == "PyPy"
1719

@@ -5233,4 +5235,4 @@ def test_all_tests():
52335235
test_python_threads,
52345236

52355237
test_all_tests,
5236-
]
5238+
] + arb_tests

src/flint/test/test_arb.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
"""Test for python-flint's `arb` type."""
2+
3+
import math
4+
5+
from flint import arb, ctx
6+
7+
from scipy.special import erf, erfc
8+
9+
def assert_almost_equal(x, y, places=7):
10+
"""Helper method for approximate comparisons."""
11+
assert round(x-y, ndigits=places) == 0
12+
13+
def test_from_int():
14+
"""Tests instantiating `arb`s from ints."""
15+
for val in [
16+
-42 * 10**9,
17+
-42 * 10**7,
18+
-42,
19+
0,
20+
42,
21+
42 * 10**7,
22+
42 * 10**9,
23+
42 * 10**11,
24+
]:
25+
x = arb(val)
26+
man, exp = x.man_exp()
27+
assert (man * 2**exp) == val
28+
29+
def test_from_float():
30+
"""Tests instantiating `arb`s from floats."""
31+
for val in [0.0, 1.1, -1.1, 9.9 * 0.123, 99.12]:
32+
x = arb(val)
33+
man, exp = x.man_exp()
34+
assert (int(man) * 2 ** int(exp)) == val
35+
36+
def test_from_float_inf():
37+
"""Tests `arb` works with +/- inf."""
38+
posinf = arb(float("inf"))
39+
neginf = arb(float("-inf"))
40+
41+
assert not posinf.is_finite()
42+
assert not neginf.is_finite()
43+
assert float(posinf) == float("inf")
44+
assert float(neginf) == float("-inf")
45+
46+
def test_from_man_exp():
47+
"""Tests instantiating `arb`s with mantissa and exponent."""
48+
for man, exp in [(2, 30), (4, 300), (5 * 10**2, 7**8)]:
49+
x = arb(mid=(man, exp))
50+
m, e = x.man_exp()
51+
assert (m * 2**e) == (man * 2**exp)
52+
53+
def test_from_midpoint_radius():
54+
"""Tests instantiating `arb`s with midpoint and radius."""
55+
for mid, rad in [(10, 1), (10000, 5), (10, 1), (10, 1)]:
56+
mid_arb = arb(mid)
57+
rad_arb = arb(rad)
58+
x = arb(mid_arb, rad_arb)
59+
assert x.mid() == mid_arb
60+
actual_radius = float(x.rad())
61+
assert_almost_equal(actual_radius, rad)
62+
63+
def test_is_exact():
64+
"""Tests `arb.is_exact`."""
65+
for arb_val, exact in [
66+
(arb(10), True),
67+
(arb(0.01), True),
68+
(arb(-float("inf")), True),
69+
(arb(1, 0), True),
70+
(arb(1, 1), False),
71+
]:
72+
assert arb_val.is_exact() == exact
73+
74+
def test_is_finite():
75+
"""Tests `arb.is_finite`."""
76+
assert not (arb(-float("inf")).is_finite())
77+
assert not (arb(float("inf")).is_finite())
78+
assert (arb(10).is_finite())
79+
80+
def test_is_nan():
81+
"""Tests `arb.is_nan`."""
82+
assert (arb(float("nan")).is_nan())
83+
assert not (arb(0.0).is_nan())
84+
85+
def test_lower():
86+
"""Tests `arb.lower`."""
87+
with ctx.workprec(100):
88+
arb_val = arb(1, 0.5)
89+
assert_almost_equal(float(arb_val.lower()), 0.5)
90+
91+
def test_upper():
92+
"""Tests `arb.upper`."""
93+
with ctx.workprec(100):
94+
arb_val = arb(1, 0.5)
95+
assert_almost_equal(float(arb_val.upper()), 1.5)
96+
97+
def test_contains():
98+
"""`y.__contains__(x)` returns True iff every number in `x` is also in `y`."""
99+
for x, y, expected in [
100+
(
101+
arb(mid=9, rad=1),
102+
arb(mid=10, rad=2),
103+
True,
104+
),
105+
(
106+
arb(mid=10, rad=2),
107+
arb(mid=9, rad=1),
108+
False,
109+
),
110+
(arb(10), arb(mid=9, rad=1), True),
111+
(arb(10.1), arb(mid=9, rad=1), False),
112+
]:
113+
assert (x in y) == expected
114+
115+
# TODO: Re-enable this if we ever add the ability to hash arbs.
116+
# def test_hash():
117+
# """`x` and `y` hash to the same value if they have the same midpoint and radius.
118+
119+
# Args:
120+
# x: An arb.
121+
# y: An arb.
122+
# expected: Whether `x` and `y` should hash to the same value.
123+
# """
124+
# def arb_pi(prec):
125+
# """Helper to calculate arb to a given precision."""
126+
# with ctx.workprec(prec):
127+
# return arb.pi()
128+
# for x, y, expected in [
129+
# (arb(10), arb(10), True),
130+
# (arb(10), arb(11), False),
131+
# (arb(10.0), arb(10), True),
132+
# (
133+
# arb(mid=10, rad=2),
134+
# arb(mid=10, rad=2),
135+
# True,
136+
# ),
137+
# (
138+
# arb(mid=10, rad=2),
139+
# arb(mid=10, rad=3),
140+
# False,
141+
# ),
142+
# (arb_pi(100), arb_pi(100), True),
143+
# (arb_pi(100), arb_pi(1000), False),
144+
# ]:
145+
# assert (hash(x) == hash(y)) == expected
146+
147+
148+
149+
# Tests for arithmetic functions in `flint.arb`.
150+
151+
# NOTE: Correctness of an arb function `F` is specified as follows:
152+
153+
# If `f` is the corresponding real-valued arithmetic function, `F` is correct
154+
# only if, for any Arb X and any real number x in the interval X,
155+
# `f(x)` is in `F(X)`.
156+
157+
# These tests assume arb.__contains__ is correct.
158+
159+
def test_arb_sub():
160+
"""`arb.__sub__` works as expected."""
161+
arb1 = arb(2, 0.5)
162+
arb2 = arb(1, 1)
163+
with ctx.workprec(100):
164+
actual = arb1 - arb2
165+
# Smallest value in diff => 1.5 - 2 = -0.5
166+
# Largest value in diff => 2.5 - 0 = 2.5
167+
true_interval = arb(1, 1.5) # [-0.5, 2.5]
168+
assert true_interval in actual
169+
170+
def test_arb_add():
171+
"""`arb.__add__` works as expected."""
172+
arb1 = arb(2, 1)
173+
arb2 = arb(1, 1)
174+
with ctx.workprec(100):
175+
actual = arb1 + arb2
176+
true_interval = arb(3, 2) # [1, 5]
177+
assert true_interval in actual
178+
179+
def test_arb_mul():
180+
"""`arb.__mul__` works as expected."""
181+
arb1 = arb(2, 1)
182+
arb2 = arb(1, 1)
183+
with ctx.workprec(100):
184+
actual = arb1 * arb2
185+
true_interval = arb(3, 3) # [0, 6]
186+
assert true_interval in actual
187+
188+
def test_arb_div():
189+
"""`arb.__div__` works as expected."""
190+
arb1 = arb(4, 1)
191+
arb2 = arb(2, 1)
192+
with ctx.workprec(100):
193+
actual = arb1 / arb2
194+
true_interval = arb(4, 1) # [3, 5]
195+
assert true_interval in actual
196+
197+
def test_arb_log():
198+
"""`arb.log` works as expected."""
199+
midpoint = (1 + math.exp(10)) / 2
200+
arb_val = arb(midpoint, midpoint - 1) # [1, exp(10)]
201+
with ctx.workprec(100):
202+
actual = arb_val.log()
203+
true_interval = arb(5, 5) # [0,10]
204+
assert true_interval in actual
205+
206+
def test_arb_exp():
207+
"""`arb.exp` works as expected."""
208+
midpoint = math.log(9) / 2
209+
arb_val = arb(midpoint, midpoint) # [0, log(9)]
210+
with ctx.workprec(100):
211+
actual = arb_val.exp()
212+
true_interval = arb(5, 4) # [1,9]
213+
assert true_interval in actual
214+
215+
def test_arb_max():
216+
"""`arb.max` works as expected."""
217+
arb1 = arb(1.5, 0.5) # [1, 2]
218+
arb2 = arb(1, 2) # [-1, 3]
219+
with ctx.workprec(100):
220+
actual = arb1.max(arb2)
221+
true_interval = arb(2, 1) # [1, 3]
222+
assert true_interval in actual
223+
224+
def test_arb_min():
225+
"""`arb.min` works as expected."""
226+
arb1 = arb(1.5, 0.5) # [1, 2]
227+
arb2 = arb(1, 2) # [-1, 3]
228+
with ctx.workprec(100):
229+
actual = arb1.min(arb2)
230+
true_interval = arb(0.5, 1.5) # [-1, 2]
231+
assert true_interval in actual
232+
233+
def test_arb_abs():
234+
"""`arb.__abs__` works as expected."""
235+
arb_val = arb(1, 2) # [-1,3]
236+
actual = abs(arb_val)
237+
true_interval = arb(1.5, 1.5)
238+
assert true_interval in actual
239+
240+
def test_arb_neg():
241+
"""`arb.neg` works as expected."""
242+
arb_val = arb(1, 2) # [-1,3]
243+
actual = arb_val.neg(exact=True)
244+
true_interval = arb(-2, 1) # [-3,1]
245+
assert true_interval in actual
246+
247+
def test_arb_neg_dunder():
248+
"""`arb.__neg__` works as expected."""
249+
arb_val = arb(1, 2) # [-1,3]
250+
actual = -arb_val
251+
true_interval = arb(-2, 1) # [-3,1]
252+
assert true_interval in actual
253+
254+
def test_arb_sgn():
255+
"""`arb.sgn` works as expected."""
256+
arb1 = arb(1, 0.5) # [0.5,1.5]
257+
arb2 = arb(-1, 0.5) # [-1.5,-0.5]
258+
arb3 = arb(1, 2) # [-1,3]
259+
assert_almost_equal(float(arb1.sgn()), 1)
260+
assert_almost_equal(float(arb2.sgn()), -1)
261+
# arb3 contains both positive and negative numbers
262+
# So, arb_sgn returns [0, 1]
263+
assert_almost_equal(float(arb3.sgn().mid()), 0)
264+
assert_almost_equal(float(arb3.sgn().rad()), 1)
265+
266+
def test_arb_erfinv():
267+
"""`arb.erfinv` works as expected."""
268+
midpoint = (erf(1 / 8) + erf(1 / 16)) / 2
269+
radius = midpoint - erf(1 / 16)
270+
arb_val = arb(midpoint, radius)
271+
with ctx.workprec(100):
272+
actual = arb_val.erfinv()
273+
true_interval = arb(3 / 32, 1 / 32) # [1/16, 1/8]
274+
assert true_interval in actual
275+
276+
def test_arb_erf():
277+
"""`arb.erf` works as expected."""
278+
arb_val = arb(2, 1)
279+
with ctx.workprec(100):
280+
actual = arb_val.erf()
281+
true_interval = arb((erf(1) + erf(3)) / 2, (erf(1) + erf(3)) / 2 - erf(1))
282+
assert true_interval in actual
283+
284+
def test_arb_erfc():
285+
"""`arb.erfc` works as expected."""
286+
arb_val = arb(2, 1)
287+
with ctx.workprec(100):
288+
actual = arb_val.erfc()
289+
true_interval = arb((erfc(1) + erfc(3)) / 2, (erfc(1) + erfc(3)) / 2 - erfc(3))
290+
assert true_interval in actual
291+
292+
def test_arb_const_pi():
293+
"""`arb.pi` works as expected."""
294+
with ctx.workprec(100):
295+
actual = arb.pi()
296+
interval_around_pi = arb(math.pi, 1e-10)
297+
assert actual in interval_around_pi
298+
299+
def test_arb_union():
300+
"""`arb.union` works as expected."""
301+
arb1 = arb(1, 0.5) # [0.5,1.5]
302+
arb2 = arb(3, 0.5) # [2.5,3.5]
303+
with ctx.workprec(100):
304+
actual = arb1.union(arb2)
305+
true_interval = arb(2, 1.5) # [0.5, 3.5]
306+
assert true_interval in actual
307+
308+
def test_arb_sum():
309+
"""`arb.__sum__` works as expected."""
310+
arb1 = arb(1, 0.5) # [0.5,1.5]
311+
arb2 = arb(2, 0.5) # [1.5,2.5]
312+
arb3 = arb(3, 0.5) # [2.5,3.5]
313+
with ctx.workprec(100):
314+
actual = arb1 + arb2 + arb3
315+
true_interval = arb(6, 1.5) # [4.5, 7.5]
316+
assert true_interval in actual
317+
318+
all_tests = [
319+
test_from_int,
320+
test_from_float,
321+
test_from_float_inf,
322+
test_from_man_exp,
323+
test_from_midpoint_radius,
324+
test_is_exact,
325+
test_is_finite,
326+
test_is_nan,
327+
test_lower,
328+
test_upper,
329+
test_contains,
330+
# TODO: Re-enable this if we ever add the ability to hash arbs.
331+
# test_hash,
332+
test_arb_sub,
333+
test_arb_add,
334+
test_arb_mul,
335+
test_arb_div,
336+
test_arb_log,
337+
test_arb_exp,
338+
test_arb_max,
339+
test_arb_min,
340+
test_arb_abs,
341+
test_arb_neg,
342+
test_arb_neg_dunder,
343+
test_arb_sgn,
344+
test_arb_erfinv,
345+
test_arb_erf,
346+
test_arb_erfc,
347+
test_arb_const_pi,
348+
test_arb_union,
349+
test_arb_sum,
350+
]

0 commit comments

Comments
 (0)