|
4 | 4 | from collections import defaultdict
|
5 | 5 | from math import sqrt
|
6 | 6 | from operator import attrgetter
|
7 |
| -from typing import Callable, List, Optional, Set, Tuple, Union |
| 7 | +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union |
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 | from scipy.linalg import norm
|
@@ -142,15 +142,20 @@ class _Interval:
|
142 | 142 | def __init__(
|
143 | 143 | self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int
|
144 | 144 | ) -> None:
|
145 |
| - self.children = [] |
146 |
| - self.data = {} |
| 145 | + self.children: List["_Interval"] = [] |
| 146 | + self.data: Dict[float, float] = {} |
147 | 147 | self.a = a
|
148 | 148 | self.b = b
|
149 | 149 | self.depth = depth
|
150 | 150 | self.rdepth = rdepth
|
151 |
| - self.done_leaves = set() |
152 |
| - self.depth_complete = None |
| 151 | + self.done_leaves: Set["_Interval"] = set() |
| 152 | + self.depth_complete: Optional[int] = None |
153 | 153 | self.removed = False
|
| 154 | + if TYPE_CHECKING: |
| 155 | + self.ndiv: int |
| 156 | + self.parent: Optional["_Interval"] |
| 157 | + self.err: float |
| 158 | + self.c: np.ndarray |
154 | 159 |
|
155 | 160 | @classmethod
|
156 | 161 | def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
|
@@ -234,7 +239,7 @@ def calc_err(self, c_old: np.ndarray) -> float:
|
234 | 239 |
|
235 | 240 | def calc_ndiv(self) -> None:
|
236 | 241 | div = self.parent.c00 and self.c00 / self.parent.c00 > 2
|
237 |
| - self.ndiv += div |
| 242 | + self.ndiv += int(div) |
238 | 243 |
|
239 | 244 | if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
|
240 | 245 | raise DivergentIntegralError
|
@@ -378,12 +383,14 @@ def __init__(self, function: Callable, bounds: Tuple[int, int], tol: float) -> N
|
378 | 383 | self.bounds = bounds
|
379 | 384 | self.tol = tol
|
380 | 385 | self.max_ivals = 1000
|
381 |
| - self.priority_split = [] |
| 386 | + self.priority_split: List[_Interval] = [] |
382 | 387 | self.data = {}
|
383 | 388 | self.pending_points = set()
|
384 |
| - self._stack = [] |
385 |
| - self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth"))) |
386 |
| - self.ivals = set() |
| 389 | + self._stack: List[float] = [] |
| 390 | + self.x_mapping: Dict[float, SortedSet] = defaultdict( |
| 391 | + lambda: SortedSet([], key=attrgetter("rdepth")) |
| 392 | + ) |
| 393 | + self.ivals: Set[_Interval] = set() |
387 | 394 | ival = _Interval.make_first(*self.bounds)
|
388 | 395 | self.add_ival(ival)
|
389 | 396 | self.first_ival = ival
|
|
0 commit comments