Skip to content

Commit bb3d419

Browse files
committed
fix all mypy issues in adaptive/learner/integrator_learner.py
1 parent fb457d5 commit bb3d419

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

adaptive/learner/integrator_learner.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from math import sqrt
66
from operator import attrgetter
7-
from typing import Callable, List, Optional, Set, Tuple, Union
7+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union
88

99
import numpy as np
1010
from scipy.linalg import norm
@@ -142,15 +142,20 @@ class _Interval:
142142
def __init__(
143143
self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int
144144
) -> None:
145-
self.children = []
146-
self.data = {}
145+
self.children: List["_Interval"] = []
146+
self.data: Dict[float, float] = {}
147147
self.a = a
148148
self.b = b
149149
self.depth = depth
150150
self.rdepth = rdepth
151-
self.done_leaves = set()
152-
self.depth_complete = None
151+
self.done_leaves: Set["_Interval"] = set()
152+
self.depth_complete: Optional[int] = None
153153
self.removed = False
154+
if TYPE_CHECKING:
155+
self.ndiv: int
156+
self.parent: Optional["_Interval"]
157+
self.err: float
158+
self.c: np.ndarray
154159

155160
@classmethod
156161
def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
@@ -234,7 +239,7 @@ def calc_err(self, c_old: np.ndarray) -> float:
234239

235240
def calc_ndiv(self) -> None:
236241
div = self.parent.c00 and self.c00 / self.parent.c00 > 2
237-
self.ndiv += div
242+
self.ndiv += int(div)
238243

239244
if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
240245
raise DivergentIntegralError
@@ -378,12 +383,14 @@ def __init__(self, function: Callable, bounds: Tuple[int, int], tol: float) -> N
378383
self.bounds = bounds
379384
self.tol = tol
380385
self.max_ivals = 1000
381-
self.priority_split = []
386+
self.priority_split: List[_Interval] = []
382387
self.data = {}
383388
self.pending_points = set()
384-
self._stack = []
385-
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
386-
self.ivals = set()
389+
self._stack: List[float] = []
390+
self.x_mapping: Dict[float, SortedSet] = defaultdict(
391+
lambda: SortedSet([], key=attrgetter("rdepth"))
392+
)
393+
self.ivals: Set[_Interval] = set()
387394
ival = _Interval.make_first(*self.bounds)
388395
self.add_ival(ival)
389396
self.first_ival = ival

0 commit comments

Comments
 (0)