Skip to content

Commit f4ecae6

Browse files
committed
type hint fixes for adaptive/learner/integrator_learner.py
1 parent 7bc15eb commit f4ecae6

File tree

1 file changed

+11
-34
lines changed

1 file changed

+11
-34
lines changed

adaptive/learner/integrator_learner.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33
import sys
44
from collections import defaultdict
5-
from functools import partial
65
from math import sqrt
76
from operator import attrgetter
8-
from typing import Any, Callable, List, Optional, Set, Tuple, Union
7+
from typing import Callable, List, Optional, Set, Tuple, Union
98

109
import numpy as np
11-
from numpy import ufunc
1210
from scipy.linalg import norm
1311
from sortedcontainers import SortedSet
1412

@@ -142,11 +140,7 @@ class _Interval:
142140
]
143141

144142
def __init__(
145-
self,
146-
a: Union[int, np.float64],
147-
b: Union[int, np.float64],
148-
depth: int,
149-
rdepth: int,
143+
self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int,
150144
) -> None:
151145
self.children = []
152146
self.data = {}
@@ -213,7 +207,7 @@ def split(self) -> List["_Interval"]:
213207
def calc_igral(self) -> None:
214208
self.igral = (self.b - self.a) * self.c[0] / sqrt(2)
215209

216-
def update_heuristic_err(self, value: Union[np.float64, float]) -> None:
210+
def update_heuristic_err(self, value: float) -> None:
217211
"""Sets the error of an interval using a heuristic (half the error of
218212
the parent) when the actual error cannot be calculated due to its
219213
parents not being finished yet. This error is propagated down to its
@@ -347,10 +341,7 @@ def __repr__(self) -> str:
347341

348342
class IntegratorLearner(BaseLearner):
349343
def __init__(
350-
self,
351-
function: Union[partial, ufunc, Callable],
352-
bounds: Tuple[int, int],
353-
tol: float,
344+
self, function: Callable, bounds: Tuple[int, int], tol: float,
354345
) -> None:
355346
"""
356347
Parameters
@@ -403,7 +394,7 @@ def __init__(
403394
def approximating_intervals(self) -> Set["_Interval"]:
404395
return self.first_ival.done_leaves
405396

406-
def tell(self, point: np.float64, value: np.float64) -> None:
397+
def tell(self, point: float, value: float) -> None:
407398
if point not in self.x_mapping:
408399
raise ValueError(f"Point {point} doesn't belong to any interval")
409400
self.data[point] = value
@@ -460,23 +451,15 @@ def add_ival(self, ival: "_Interval") -> None:
460451
self._stack.append(x)
461452
self.ivals.add(ival)
462453

463-
def ask(
464-
self, n: int, tell_pending: bool = True
465-
) -> Union[
466-
Tuple[List[np.float64], List[np.float64]], Tuple[List[np.float64], List[float]]
467-
]:
454+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[float]]:
468455
"""Choose points for learners."""
469456
if not tell_pending:
470457
with restore(self):
471458
return self._ask_and_tell_pending(n)
472459
else:
473460
return self._ask_and_tell_pending(n)
474461

475-
def _ask_and_tell_pending(
476-
self, n: int
477-
) -> Union[
478-
Tuple[List[np.float64], List[np.float64]], Tuple[List[np.float64], List[float]]
479-
]:
462+
def _ask_and_tell_pending(self, n: int) -> Tuple[List[float], List[float]]:
480463
points, loss_improvements = self.pop_from_stack(n)
481464
n_left = n - len(points)
482465
while n_left > 0:
@@ -492,13 +475,7 @@ def _ask_and_tell_pending(
492475

493476
return points, loss_improvements
494477

495-
def pop_from_stack(
496-
self, n: int
497-
) -> Union[
498-
Tuple[List[np.float64], List[np.float64]],
499-
Tuple[List[Any], List[Any]],
500-
Tuple[List[np.float64], List[float]],
501-
]:
478+
def pop_from_stack(self, n: int) -> Tuple[List[float], List[float]]:
502479
points = self._stack[:n]
503480
self._stack = self._stack[n:]
504481
loss_improvements = [
@@ -509,7 +486,7 @@ def pop_from_stack(
509486
def remove_unfinished(self):
510487
pass
511488

512-
def _fill_stack(self) -> List[np.float64]:
489+
def _fill_stack(self) -> List[float]:
513490
# XXX: to-do if all the ivals have err=inf, take the interval
514491
# with the lowest rdepth and no children.
515492
force_split = bool(self.priority_split)
@@ -550,11 +527,11 @@ def npoints(self) -> int:
550527
return len(self.data)
551528

552529
@property
553-
def igral(self) -> np.float64:
530+
def igral(self) -> float:
554531
return sum(i.igral for i in self.approximating_intervals)
555532

556533
@property
557-
def err(self) -> np.float64:
534+
def err(self) -> float:
558535
if self.approximating_intervals:
559536
err = sum(i.err for i in self.approximating_intervals)
560537
if err > sys.float_info.max:

0 commit comments

Comments
 (0)