1+ """Test for python-flint's `arb` type."""
2+
3+ import math
4+
5+ from flint import arb , ctx
6+
7+ from scipy .special import erf , erfc
8+
9+ def assert_almost_equal (x , y , places = 7 ):
10+ """Helper method for approximate comparisons."""
11+ assert round (x - y , ndigits = places ) == 0
12+
13+ def test_from_int ():
14+ """Tests instantiating `arb`s from ints."""
15+ for val in [
16+ - 42 * 10 ** 9 ,
17+ - 42 * 10 ** 7 ,
18+ - 42 ,
19+ 0 ,
20+ 42 ,
21+ 42 * 10 ** 7 ,
22+ 42 * 10 ** 9 ,
23+ 42 * 10 ** 11 ,
24+ ]:
25+ x = arb (val )
26+ man , exp = x .man_exp ()
27+ assert (man * 2 ** exp ) == val
28+
29+ def test_from_float ():
30+ """Tests instantiating `arb`s from floats."""
31+ for val in [0.0 , 1.1 , - 1.1 , 9.9 * 0.123 , 99.12 ]:
32+ x = arb (val )
33+ man , exp = x .man_exp ()
34+ assert (int (man ) * 2 ** int (exp )) == val
35+
36+ def test_from_float_inf ():
37+ """Tests `arb` works with +/- inf."""
38+ posinf = arb (float ("inf" ))
39+ neginf = arb (float ("-inf" ))
40+
41+ assert not posinf .is_finite ()
42+ assert not neginf .is_finite ()
43+ assert float (posinf ) == float ("inf" )
44+ assert float (neginf ) == float ("-inf" )
45+
46+ def test_from_man_exp ():
47+ """Tests instantiating `arb`s with mantissa and exponent."""
48+ for man , exp in [(2 , 30 ), (4 , 300 ), (5 * 10 ** 2 , 7 ** 8 )]:
49+ x = arb (mid = (man , exp ))
50+ m , e = x .man_exp ()
51+ assert (m * 2 ** e ) == (man * 2 ** exp )
52+
53+ def test_from_midpoint_radius ():
54+ """Tests instantiating `arb`s with midpoint and radius."""
55+ for mid , rad in [(10 , 1 ), (10000 , 5 ), (10 , 1 ), (10 , 1 )]:
56+ mid_arb = arb (mid )
57+ rad_arb = arb (rad )
58+ x = arb (mid_arb , rad_arb )
59+ assert x .mid () == mid_arb
60+ actual_radius = float (x .rad ())
61+ assert_almost_equal (actual_radius , rad )
62+
63+ def test_is_exact ():
64+ """Tests `arb.is_exact`."""
65+ for arb_val , exact in [
66+ (arb (10 ), True ),
67+ (arb (0.01 ), True ),
68+ (arb (- float ("inf" )), True ),
69+ (arb (1 , 0 ), True ),
70+ (arb (1 , 1 ), False ),
71+ ]:
72+ assert arb_val .is_exact () == exact
73+
74+ def test_is_finite ():
75+ """Tests `arb.is_finite`."""
76+ assert not (arb (- float ("inf" )).is_finite ())
77+ assert not (arb (float ("inf" )).is_finite ())
78+ assert (arb (10 ).is_finite ())
79+
80+ def test_is_nan ():
81+ """Tests `arb.is_nan`."""
82+ assert (arb (float ("nan" )).is_nan ())
83+ assert not (arb (0.0 ).is_nan ())
84+
85+ def test_lower ():
86+ """Tests `arb.lower`."""
87+ with ctx .workprec (100 ):
88+ arb_val = arb (1 , 0.5 )
89+ assert_almost_equal (float (arb_val .lower ()), 0.5 )
90+
91+ def test_upper ():
92+ """Tests `arb.upper`."""
93+ with ctx .workprec (100 ):
94+ arb_val = arb (1 , 0.5 )
95+ assert_almost_equal (float (arb_val .upper ()), 1.5 )
96+
97+ def test_contains ():
98+ """`y.__contains__(x)` returns True iff every number in `x` is also in `y`."""
99+ for x , y , expected in [
100+ (
101+ arb (mid = 9 , rad = 1 ),
102+ arb (mid = 10 , rad = 2 ),
103+ True ,
104+ ),
105+ (
106+ arb (mid = 10 , rad = 2 ),
107+ arb (mid = 9 , rad = 1 ),
108+ False ,
109+ ),
110+ (arb (10 ), arb (mid = 9 , rad = 1 ), True ),
111+ (arb (10.1 ), arb (mid = 9 , rad = 1 ), False ),
112+ ]:
113+ assert (x in y ) == expected
114+
115+ # TODO: Re-enable this if we ever add the ability to hash arbs.
116+ # def test_hash():
117+ # """`x` and `y` hash to the same value if they have the same midpoint and radius.
118+
119+ # Args:
120+ # x: An arb.
121+ # y: An arb.
122+ # expected: Whether `x` and `y` should hash to the same value.
123+ # """
124+ # def arb_pi(prec):
125+ # """Helper to calculate arb to a given precision."""
126+ # with ctx.workprec(prec):
127+ # return arb.pi()
128+ # for x, y, expected in [
129+ # (arb(10), arb(10), True),
130+ # (arb(10), arb(11), False),
131+ # (arb(10.0), arb(10), True),
132+ # (
133+ # arb(mid=10, rad=2),
134+ # arb(mid=10, rad=2),
135+ # True,
136+ # ),
137+ # (
138+ # arb(mid=10, rad=2),
139+ # arb(mid=10, rad=3),
140+ # False,
141+ # ),
142+ # (arb_pi(100), arb_pi(100), True),
143+ # (arb_pi(100), arb_pi(1000), False),
144+ # ]:
145+ # assert (hash(x) == hash(y)) == expected
146+
147+
148+
149+ # Tests for arithmetic functions in `flint.arb`.
150+
151+ # NOTE: Correctness of an arb function `F` is specified as follows:
152+
153+ # If `f` is the corresponding real-valued arithmetic function, `F` is correct
154+ # only if, for any Arb X and any real number x in the interval X,
155+ # `f(x)` is in `F(X)`.
156+
157+ # These tests assume arb.__contains__ is correct.
158+
159+ def test_arb_sub ():
160+ """`arb.__sub__` works as expected."""
161+ arb1 = arb (2 , 0.5 )
162+ arb2 = arb (1 , 1 )
163+ with ctx .workprec (100 ):
164+ actual = arb1 - arb2
165+ # Smallest value in diff => 1.5 - 2 = -0.5
166+ # Largest value in diff => 2.5 - 0 = 2.5
167+ true_interval = arb (1 , 1.5 ) # [-0.5, 2.5]
168+ assert true_interval in actual
169+
170+ def test_arb_add ():
171+ """`arb.__add__` works as expected."""
172+ arb1 = arb (2 , 1 )
173+ arb2 = arb (1 , 1 )
174+ with ctx .workprec (100 ):
175+ actual = arb1 + arb2
176+ true_interval = arb (3 , 2 ) # [1, 5]
177+ assert true_interval in actual
178+
179+ def test_arb_mul ():
180+ """`arb.__mul__` works as expected."""
181+ arb1 = arb (2 , 1 )
182+ arb2 = arb (1 , 1 )
183+ with ctx .workprec (100 ):
184+ actual = arb1 * arb2
185+ true_interval = arb (3 , 3 ) # [0, 6]
186+ assert true_interval in actual
187+
188+ def test_arb_div ():
189+ """`arb.__div__` works as expected."""
190+ arb1 = arb (4 , 1 )
191+ arb2 = arb (2 , 1 )
192+ with ctx .workprec (100 ):
193+ actual = arb1 / arb2
194+ true_interval = arb (4 , 1 ) # [3, 5]
195+ assert true_interval in actual
196+
197+ def test_arb_log ():
198+ """`arb.log` works as expected."""
199+ midpoint = (1 + math .exp (10 )) / 2
200+ arb_val = arb (midpoint , midpoint - 1 ) # [1, exp(10)]
201+ with ctx .workprec (100 ):
202+ actual = arb_val .log ()
203+ true_interval = arb (5 , 5 ) # [0,10]
204+ assert true_interval in actual
205+
206+ def test_arb_exp ():
207+ """`arb.exp` works as expected."""
208+ midpoint = math .log (9 ) / 2
209+ arb_val = arb (midpoint , midpoint ) # [0, log(9)]
210+ with ctx .workprec (100 ):
211+ actual = arb_val .exp ()
212+ true_interval = arb (5 , 4 ) # [1,9]
213+ assert true_interval in actual
214+
215+ def test_arb_max ():
216+ """`arb.max` works as expected."""
217+ arb1 = arb (1.5 , 0.5 ) # [1, 2]
218+ arb2 = arb (1 , 2 ) # [-1, 3]
219+ with ctx .workprec (100 ):
220+ actual = arb1 .max (arb2 )
221+ true_interval = arb (2 , 1 ) # [1, 3]
222+ assert true_interval in actual
223+
224+ def test_arb_min ():
225+ """`arb.min` works as expected."""
226+ arb1 = arb (1.5 , 0.5 ) # [1, 2]
227+ arb2 = arb (1 , 2 ) # [-1, 3]
228+ with ctx .workprec (100 ):
229+ actual = arb1 .min (arb2 )
230+ true_interval = arb (0.5 , 1.5 ) # [-1, 2]
231+ assert true_interval in actual
232+
233+ def test_arb_abs ():
234+ """`arb.__abs__` works as expected."""
235+ arb_val = arb (1 , 2 ) # [-1,3]
236+ actual = abs (arb_val )
237+ true_interval = arb (1.5 , 1.5 )
238+ assert true_interval in actual
239+
240+ def test_arb_neg ():
241+ """`arb.neg` works as expected."""
242+ arb_val = arb (1 , 2 ) # [-1,3]
243+ actual = arb_val .neg (exact = True )
244+ true_interval = arb (- 2 , 1 ) # [-3,1]
245+ assert true_interval in actual
246+
247+ def test_arb_neg_dunder ():
248+ """`arb.__neg__` works as expected."""
249+ arb_val = arb (1 , 2 ) # [-1,3]
250+ actual = - arb_val
251+ true_interval = arb (- 2 , 1 ) # [-3,1]
252+ assert true_interval in actual
253+
254+ def test_arb_sgn ():
255+ """`arb.sgn` works as expected."""
256+ arb1 = arb (1 , 0.5 ) # [0.5,1.5]
257+ arb2 = arb (- 1 , 0.5 ) # [-1.5,-0.5]
258+ arb3 = arb (1 , 2 ) # [-1,3]
259+ assert_almost_equal (float (arb1 .sgn ()), 1 )
260+ assert_almost_equal (float (arb2 .sgn ()), - 1 )
261+ # arb3 contains both positive and negative numbers
262+ # So, arb_sgn returns [0, 1]
263+ assert_almost_equal (float (arb3 .sgn ().mid ()), 0 )
264+ assert_almost_equal (float (arb3 .sgn ().rad ()), 1 )
265+
266+ def test_arb_erfinv ():
267+ """`arb.erfinv` works as expected."""
268+ midpoint = (erf (1 / 8 ) + erf (1 / 16 )) / 2
269+ radius = midpoint - erf (1 / 16 )
270+ arb_val = arb (midpoint , radius )
271+ with ctx .workprec (100 ):
272+ actual = arb_val .erfinv ()
273+ true_interval = arb (3 / 32 , 1 / 32 ) # [1/16, 1/8]
274+ assert true_interval in actual
275+
276+ def test_arb_erf ():
277+ """`arb.erf` works as expected."""
278+ arb_val = arb (2 , 1 )
279+ with ctx .workprec (100 ):
280+ actual = arb_val .erf ()
281+ true_interval = arb ((erf (1 ) + erf (3 )) / 2 , (erf (1 ) + erf (3 )) / 2 - erf (1 ))
282+ assert true_interval in actual
283+
284+ def test_arb_erfc ():
285+ """`arb.erfc` works as expected."""
286+ arb_val = arb (2 , 1 )
287+ with ctx .workprec (100 ):
288+ actual = arb_val .erfc ()
289+ true_interval = arb ((erfc (1 ) + erfc (3 )) / 2 , (erfc (1 ) + erfc (3 )) / 2 - erfc (3 ))
290+ assert true_interval in actual
291+
292+ def test_arb_const_pi ():
293+ """`arb.pi` works as expected."""
294+ with ctx .workprec (100 ):
295+ actual = arb .pi ()
296+ interval_around_pi = arb (math .pi , 1e-10 )
297+ assert actual in interval_around_pi
298+
299+ def test_arb_union ():
300+ """`arb.union` works as expected."""
301+ arb1 = arb (1 , 0.5 ) # [0.5,1.5]
302+ arb2 = arb (3 , 0.5 ) # [2.5,3.5]
303+ with ctx .workprec (100 ):
304+ actual = arb1 .union (arb2 )
305+ true_interval = arb (2 , 1.5 ) # [0.5, 3.5]
306+ assert true_interval in actual
307+
308+ def test_arb_sum ():
309+ """`arb.__sum__` works as expected."""
310+ arb1 = arb (1 , 0.5 ) # [0.5,1.5]
311+ arb2 = arb (2 , 0.5 ) # [1.5,2.5]
312+ arb3 = arb (3 , 0.5 ) # [2.5,3.5]
313+ with ctx .workprec (100 ):
314+ actual = arb1 + arb2 + arb3
315+ true_interval = arb (6 , 1.5 ) # [4.5, 7.5]
316+ assert true_interval in actual
317+
318+ all_tests = [
319+ test_from_int ,
320+ test_from_float ,
321+ test_from_float_inf ,
322+ test_from_man_exp ,
323+ test_from_midpoint_radius ,
324+ test_is_exact ,
325+ test_is_finite ,
326+ test_is_nan ,
327+ test_lower ,
328+ test_upper ,
329+ test_contains ,
330+ # TODO: Re-enable this if we ever add the ability to hash arbs.
331+ # test_hash,
332+ test_arb_sub ,
333+ test_arb_add ,
334+ test_arb_mul ,
335+ test_arb_div ,
336+ test_arb_log ,
337+ test_arb_exp ,
338+ test_arb_max ,
339+ test_arb_min ,
340+ test_arb_abs ,
341+ test_arb_neg ,
342+ test_arb_neg_dunder ,
343+ test_arb_sgn ,
344+ test_arb_erfinv ,
345+ test_arb_erf ,
346+ test_arb_erfc ,
347+ test_arb_const_pi ,
348+ test_arb_union ,
349+ test_arb_sum ,
350+ ]
0 commit comments