Skip to content

Commit 5de47bb

Browse files
committed
add type annotations for adaptive/tests/algorithm_4.py
1 parent d8d2fba commit 5de47bb

File tree

1 file changed

+43
-28
lines changed

1 file changed

+43
-28
lines changed

adaptive/tests/algorithm_4.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,40 @@
22
# Copyright 2017 Christoph Groth
33

44
from collections import defaultdict
5-
from fractions import Fraction as Frac
5+
from fractions import Fraction
6+
from functools import partial
7+
from typing import Callable, List, Tuple, Union
68

79
import numpy as np
10+
from numpy import float64, ndarray
811
from numpy.testing import assert_allclose
912
from scipy.linalg import inv, norm
1013

1114
eps = np.spacing(1)
1215

1316

14-
def legendre(n):
17+
def legendre(n: int) -> List[List[Fraction]]:
1518
"""Return the first n Legendre polynomials.
1619
1720
The polynomials have *standard* normalization, i.e.
1821
int_{-1}^1 dx L_n(x) L_m(x) = delta(m, n) * 2 / (2 * n + 1).
1922
2023
The return value is a list of list of fraction.Fraction instances.
2124
"""
22-
result = [[Frac(1)], [Frac(0), Frac(1)]]
25+
result = [[Fraction(1)], [Fraction(0), Fraction(1)]]
2326
if n <= 2:
2427
return result[:n]
2528
for i in range(2, n):
2629
# Use Bonnet's recursion formula.
27-
new = (i + 1) * [Frac(0)]
30+
new = (i + 1) * [Fraction(0)]
2831
new[1:] = (r * (2 * i - 1) for r in result[-1])
2932
new[:-2] = (n - r * (i - 1) for n, r in zip(new[:-2], result[-2]))
3033
new[:] = (n / i for n in new)
3134
result.append(new)
3235
return result
3336

3437

35-
def newton(n):
38+
def newton(n: int) -> ndarray:
3639
"""Compute the monomial coefficients of the Newton polynomial over the
3740
nodes of the n-point Clenshaw-Curtis quadrature rule.
3841
"""
@@ -89,7 +92,7 @@ def newton(n):
8992
return cf
9093

9194

92-
def scalar_product(a, b):
95+
def scalar_product(a: List[Fraction], b: List[Fraction]) -> Fraction:
9396
"""Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
9497
9598
The args must be sequences of polynomial coefficients. This
@@ -110,7 +113,7 @@ def scalar_product(a, b):
110113
return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
111114

112115

113-
def calc_bdef(ns):
116+
def calc_bdef(ns: Tuple[int, int, int, int]) -> List[ndarray]:
114117
"""Calculate the decompositions of Newton polynomials (over the nodes
115118
of the n-point Clenshaw-Curtis quadrature rule) in terms of
116119
Legandre polynomials.
@@ -123,7 +126,7 @@ def calc_bdef(ns):
123126
result = []
124127
for n in ns:
125128
poly = []
126-
a = list(map(Frac, newton(n)))
129+
a = list(map(Fraction, newton(n)))
127130
for b in legs[: n + 1]:
128131
igral = scalar_product(a, b)
129132

@@ -145,7 +148,7 @@ def calc_bdef(ns):
145148
b_def = calc_bdef(n)
146149

147150

148-
def calc_V(xi, n):
151+
def calc_V(xi: ndarray, n: int) -> ndarray:
149152
V = [np.ones(xi.shape), xi.copy()]
150153
for i in range(2, n):
151154
V.append((2 * i - 1) / i * xi * V[-1] - (i - 1) / i * V[-2])
@@ -183,7 +186,7 @@ def calc_V(xi, n):
183186
gamma = np.concatenate([[0, 0], np.sqrt(k[2:] ** 2 / (4 * k[2:] ** 2 - 1))])
184187

185188

186-
def _downdate(c, nans, depth):
189+
def _downdate(c: ndarray, nans: List[int], depth: int) -> None:
187190
# This is algorithm 5 from the thesis of Pedro Gonnet.
188191
b = b_def[depth].copy()
189192
m = n[depth] - 1
@@ -200,7 +203,7 @@ def _downdate(c, nans, depth):
200203
m -= 1
201204

202205

203-
def _zero_nans(fx):
206+
def _zero_nans(fx: ndarray) -> List[int]:
204207
nans = []
205208
for i in range(len(fx)):
206209
if not np.isfinite(fx[i]):
@@ -209,7 +212,7 @@ def _zero_nans(fx):
209212
return nans
210213

211214

212-
def _calc_coeffs(fx, depth):
215+
def _calc_coeffs(fx: ndarray, depth: int) -> ndarray:
213216
"""Caution: this function modifies fx."""
214217
nans = _zero_nans(fx)
215218
c_new = V_inv[depth] @ fx
@@ -220,7 +223,7 @@ def _calc_coeffs(fx, depth):
220223

221224

222225
class DivergentIntegralError(ValueError):
223-
def __init__(self, msg, igral, err, nr_points):
226+
def __init__(self, msg: str, igral: float64, err: None, nr_points: int) -> None:
224227
self.igral = igral
225228
self.err = err
226229
self.nr_points = nr_points
@@ -230,19 +233,23 @@ def __init__(self, msg, igral, err, nr_points):
230233
class _Interval:
231234
__slots__ = ["a", "b", "c", "fx", "igral", "err", "depth", "rdepth", "ndiv", "c00"]
232235

233-
def __init__(self, a, b, depth, rdepth):
236+
def __init__(
237+
self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int
238+
) -> None:
234239
self.a = a
235240
self.b = b
236241
self.depth = depth
237242
self.rdepth = rdepth
238243

239-
def points(self):
244+
def points(self) -> ndarray:
240245
a = self.a
241246
b = self.b
242247
return (a + b) / 2 + (b - a) * xi[self.depth] / 2
243248

244249
@classmethod
245-
def make_first(cls, f, a, b, depth=2):
250+
def make_first(
251+
cls, f: Union[partial, Callable], a: int, b: int, depth: int = 2
252+
) -> Tuple["_Interval", int]:
246253
ival = _Interval(a, b, depth, 1)
247254
fx = f(ival.points())
248255
ival.c = _calc_coeffs(fx, depth)
@@ -251,7 +258,7 @@ def make_first(cls, f, a, b, depth=2):
251258
ival.ndiv = 0
252259
return ival, n[depth]
253260

254-
def calc_igral_and_err(self, c_old):
261+
def calc_igral_and_err(self, c_old: ndarray) -> float:
255262
self.c = c_new = _calc_coeffs(self.fx, self.depth)
256263
c_diff = np.zeros(max(len(c_old), len(c_new)))
257264
c_diff[: len(c_old)] = c_old
@@ -262,7 +269,9 @@ def calc_igral_and_err(self, c_old):
262269
self.err = w * c_diff
263270
return c_diff
264271

265-
def split(self, f):
272+
def split(
273+
self, f: Union[partial, Callable]
274+
) -> Union[Tuple[Tuple[float, float, float], int], Tuple[List["_Interval"], int]]:
266275
m = (self.a + self.b) / 2
267276
f_center = self.fx[(len(self.fx) - 1) // 2]
268277

@@ -287,7 +296,7 @@ def split(self, f):
287296

288297
return ivals, nr_points
289298

290-
def refine(self, f):
299+
def refine(self, f: Union[partial, Callable]) -> Tuple[ndarray, bool, int]:
291300
"""Increase degree of interval."""
292301
self.depth = depth = self.depth + 1
293302
points = self.points()
@@ -299,7 +308,9 @@ def refine(self, f):
299308
return points, split, n[depth] - n[depth - 1]
300309

301310

302-
def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
311+
def algorithm_4(
312+
f: Union[partial, Callable], a: int, b: int, tol: float, N_loops: int = int(1e9)
313+
) -> Tuple[float64, float, int, List["_Interval"]]:
303314
"""ALGORITHM_4 evaluates an integral using adaptive quadrature. The
304315
algorithm uses Clenshaw-Curtis quadrature rules of increasing
305316
degree in each interval and bisects the interval if either the
@@ -403,37 +414,39 @@ def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
403414
return igral, err, nr_points, ivals
404415

405416

406-
################ Tests ################
417+
# ############### Tests ################
407418

408419

409-
def f0(x):
420+
def f0(x: Union[float64, ndarray]) -> Union[float64, ndarray]:
410421
return x * np.sin(1 / x) * np.sqrt(abs(1 - x))
411422

412423

413-
def f7(x):
424+
def f7(x: Union[float64, ndarray]) -> Union[float64, ndarray]:
414425
return x ** -0.5
415426

416427

417-
def f24(x):
428+
def f24(x: Union[float64, ndarray]) -> Union[float64, ndarray]:
418429
return np.floor(np.exp(x))
419430

420431

421-
def f21(x):
432+
def f21(x: Union[float64, ndarray]) -> Union[float64, ndarray]:
422433
y = 0
423434
for i in range(1, 4):
424435
y += 1 / np.cosh(20 ** i * (x - 2 * i / 10))
425436
return y
426437

427438

428-
def f63(x, alpha, beta):
439+
def f63(
440+
x: Union[float64, ndarray], alpha: float, beta: float
441+
) -> Union[float64, ndarray]:
429442
return abs(x - beta) ** alpha
430443

431444

432445
def F63(x, alpha, beta):
433446
return (x - beta) * abs(x - beta) ** alpha / (alpha + 1)
434447

435448

436-
def fdiv(x):
449+
def fdiv(x: Union[float64, ndarray]) -> Union[float64, ndarray]:
437450
return abs(x - 0.987654321) ** -1.1
438451

439452

@@ -461,7 +474,9 @@ def test_scalar_product(n=33):
461474
selection = [0, 5, 7, n - 1]
462475
for i in selection:
463476
for j in selection:
464-
assert scalar_product(legs[i], legs[j]) == ((i == j) and Frac(2, 2 * i + 1))
477+
assert scalar_product(legs[i], legs[j]) == (
478+
(i == j) and Fraction(2, 2 * i + 1)
479+
)
465480

466481

467482
def simple_newton(n):

0 commit comments

Comments
 (0)