Skip to content

Commit e394995

Browse files
committed
add type annotations for adaptive/learner/integrator_learner.py
1 parent dd95dbb commit e394995

File tree

1 file changed

+51
-29
lines changed

1 file changed

+51
-29
lines changed

adaptive/learner/integrator_learner.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
import sys
44
from collections import defaultdict
5+
from functools import partial
56
from math import sqrt
67
from operator import attrgetter
8+
from typing import Any, Callable, List, Optional, Set, Tuple, Union
79

810
import numpy as np
11+
from numpy import bool_, float64, ndarray, ufunc
912
from scipy.linalg import norm
1013
from sortedcontainers import SortedSet
1114

@@ -30,7 +33,7 @@
3033
)
3134

3235

33-
def _downdate(c, nans, depth):
36+
def _downdate(c: ndarray, nans: List[int], depth: int) -> ndarray:
3437
# This is algorithm 5 from the thesis of Pedro Gonnet.
3538
b = b_def[depth].copy()
3639
m = ns[depth] - 1
@@ -48,7 +51,7 @@ def _downdate(c, nans, depth):
4851
return c
4952

5053

51-
def _zero_nans(fx):
54+
def _zero_nans(fx: ndarray) -> List[int]:
5255
"""Caution: this function modifies fx."""
5356
nans = []
5457
for i in range(len(fx)):
@@ -58,7 +61,7 @@ def _zero_nans(fx):
5861
return nans
5962

6063

61-
def _calc_coeffs(fx, depth):
64+
def _calc_coeffs(fx: ndarray, depth: int) -> ndarray:
6265
"""Caution: this function modifies fx."""
6366
nans = _zero_nans(fx)
6467
c_new = V_inv[depth] @ fx
@@ -138,7 +141,9 @@ class _Interval:
138141
"removed",
139142
]
140143

141-
def __init__(self, a, b, depth, rdepth):
144+
def __init__(
145+
self, a: Union[int, float64], b: Union[int, float64], depth: int, rdepth: int
146+
) -> None:
142147
self.children = []
143148
self.data = {}
144149
self.a = a
@@ -150,15 +155,15 @@ def __init__(self, a, b, depth, rdepth):
150155
self.removed = False
151156

152157
@classmethod
153-
def make_first(cls, a, b, depth=2):
158+
def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
154159
ival = _Interval(a, b, depth, rdepth=1)
155160
ival.ndiv = 0
156161
ival.parent = None
157162
ival.err = sys.float_info.max # needed because inf/2 == inf
158163
return ival
159164

160165
@property
161-
def T(self):
166+
def T(self) -> ndarray:
162167
"""Get the correct shift matrix.
163168
164169
Should only be called on children of a split interval.
@@ -169,24 +174,24 @@ def T(self):
169174
assert left != right
170175
return T_left if left else T_right
171176

172-
def refinement_complete(self, depth):
177+
def refinement_complete(self, depth: int) -> bool:
173178
"""The interval has all the y-values to calculate the intergral."""
174179
if len(self.data) < ns[depth]:
175180
return False
176181
return all(p in self.data for p in self.points(depth))
177182

178-
def points(self, depth=None):
183+
def points(self, depth: Optional[int] = None) -> ndarray:
179184
if depth is None:
180185
depth = self.depth
181186
a = self.a
182187
b = self.b
183188
return (a + b) / 2 + (b - a) * xi[depth] / 2
184189

185-
def refine(self):
190+
def refine(self) -> "_Interval":
186191
self.depth += 1
187192
return self
188193

189-
def split(self):
194+
def split(self) -> List["_Interval"]:
190195
points = self.points()
191196
m = points[len(points) // 2]
192197
ivals = [
@@ -201,10 +206,10 @@ def split(self):
201206

202207
return ivals
203208

204-
def calc_igral(self):
209+
def calc_igral(self) -> None:
205210
self.igral = (self.b - self.a) * self.c[0] / sqrt(2)
206211

207-
def update_heuristic_err(self, value):
212+
def update_heuristic_err(self, value: Union[float64, float]) -> None:
208213
"""Sets the error of an interval using a heuristic (half the error of
209214
the parent) when the actual error cannot be calculated due to its
210215
parents not being finished yet. This error is propagated down to its
@@ -217,7 +222,7 @@ def update_heuristic_err(self, value):
217222
continue
218223
child.update_heuristic_err(value / 2)
219224

220-
def calc_err(self, c_old):
225+
def calc_err(self, c_old: ndarray) -> float:
221226
c_new = self.c
222227
c_diff = np.zeros(max(len(c_old), len(c_new)))
223228
c_diff[: len(c_old)] = c_old
@@ -229,7 +234,7 @@ def calc_err(self, c_old):
229234
child.update_heuristic_err(self.err / 2)
230235
return c_diff
231236

232-
def calc_ndiv(self):
237+
def calc_ndiv(self) -> None:
233238
div = self.parent.c00 and self.c00 / self.parent.c00 > 2
234239
self.ndiv += div
235240

@@ -240,15 +245,17 @@ def calc_ndiv(self):
240245
for child in self.children:
241246
child.update_ndiv_recursively()
242247

243-
def update_ndiv_recursively(self):
248+
def update_ndiv_recursively(self) -> None:
244249
self.ndiv += 1
245250
if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
246251
raise DivergentIntegralError
247252

248253
for child in self.children:
249254
child.update_ndiv_recursively()
250255

251-
def complete_process(self, depth):
256+
def complete_process(
257+
self, depth: int
258+
) -> Union[Tuple[bool, bool], Tuple[bool, bool_]]:
252259
"""Calculate the integral contribution and error from this interval,
253260
and update the done leaves of all ancestor intervals."""
254261
assert self.depth_complete is None or self.depth_complete == depth - 1
@@ -323,7 +330,7 @@ def complete_process(self, depth):
323330

324331
return force_split, remove
325332

326-
def __repr__(self):
333+
def __repr__(self) -> str:
327334
lst = [
328335
f"(a, b)=({self.a:.5f}, {self.b:.5f})",
329336
f"depth={self.depth}",
@@ -335,7 +342,12 @@ def __repr__(self):
335342

336343

337344
class IntegratorLearner(BaseLearner):
338-
def __init__(self, function, bounds, tol):
345+
def __init__(
346+
self,
347+
function: Union[partial, ufunc, Callable],
348+
bounds: Tuple[int, int],
349+
tol: float,
350+
) -> None:
339351
"""
340352
Parameters
341353
----------
@@ -384,10 +396,10 @@ def __init__(self, function, bounds, tol):
384396
self.first_ival = ival
385397

386398
@property
387-
def approximating_intervals(self):
399+
def approximating_intervals(self) -> Set["_Interval"]:
388400
return self.first_ival.done_leaves
389401

390-
def tell(self, point, value):
402+
def tell(self, point: float64, value: float64) -> None:
391403
if point not in self.x_mapping:
392404
raise ValueError(f"Point {point} doesn't belong to any interval")
393405
self.data[point] = value
@@ -423,7 +435,7 @@ def tell(self, point, value):
423435
def tell_pending(self):
424436
pass
425437

426-
def propagate_removed(self, ival):
438+
def propagate_removed(self, ival: "_Interval") -> None:
427439
def _propagate_removed_down(ival):
428440
ival.removed = True
429441
self.ivals.discard(ival)
@@ -433,7 +445,7 @@ def _propagate_removed_down(ival):
433445

434446
_propagate_removed_down(ival)
435447

436-
def add_ival(self, ival):
448+
def add_ival(self, ival: "_Interval") -> None:
437449
for x in ival.points():
438450
# Update the mappings
439451
self.x_mapping[x].add(ival)
@@ -444,15 +456,19 @@ def add_ival(self, ival):
444456
self._stack.append(x)
445457
self.ivals.add(ival)
446458

447-
def ask(self, n, tell_pending=True):
459+
def ask(
460+
self, n: int, tell_pending: bool = True
461+
) -> Union[Tuple[List[float64], List[float64]], Tuple[List[float64], List[float]]]:
448462
"""Choose points for learners."""
449463
if not tell_pending:
450464
with restore(self):
451465
return self._ask_and_tell_pending(n)
452466
else:
453467
return self._ask_and_tell_pending(n)
454468

455-
def _ask_and_tell_pending(self, n):
469+
def _ask_and_tell_pending(
470+
self, n: int
471+
) -> Union[Tuple[List[float64], List[float64]], Tuple[List[float64], List[float]]]:
456472
points, loss_improvements = self.pop_from_stack(n)
457473
n_left = n - len(points)
458474
while n_left > 0:
@@ -468,7 +484,13 @@ def _ask_and_tell_pending(self, n):
468484

469485
return points, loss_improvements
470486

471-
def pop_from_stack(self, n):
487+
def pop_from_stack(
488+
self, n: int
489+
) -> Union[
490+
Tuple[List[float64], List[float64]],
491+
Tuple[List[Any], List[Any]],
492+
Tuple[List[float64], List[float]],
493+
]:
472494
points = self._stack[:n]
473495
self._stack = self._stack[n:]
474496
loss_improvements = [
@@ -479,7 +501,7 @@ def pop_from_stack(self, n):
479501
def remove_unfinished(self):
480502
pass
481503

482-
def _fill_stack(self):
504+
def _fill_stack(self) -> List[float64]:
483505
# XXX: to-do if all the ivals have err=inf, take the interval
484506
# with the lowest rdepth and no children.
485507
force_split = bool(self.priority_split)
@@ -515,16 +537,16 @@ def _fill_stack(self):
515537
return self._stack
516538

517539
@property
518-
def npoints(self):
540+
def npoints(self) -> int:
519541
"""Number of evaluated points."""
520542
return len(self.data)
521543

522544
@property
523-
def igral(self):
545+
def igral(self) -> float64:
524546
return sum(i.igral for i in self.approximating_intervals)
525547

526548
@property
527-
def err(self):
549+
def err(self) -> float64:
528550
if self.approximating_intervals:
529551
err = sum(i.err for i in self.approximating_intervals)
530552
if err > sys.float_info.max:

0 commit comments

Comments
 (0)