Skip to content

Commit 8f8c5dc

Browse files
committed
use data types from np namespace instead importing it separately
1 parent 0f27eb8 commit 8f8c5dc

13 files changed

+271
-230
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from typing import Any, Callable, Dict, List, Set, Tuple, Union
88

99
import numpy as np
10-
from numpy import float64, int64
1110

1211
from adaptive.learner.base_learner import BaseLearner
1312
from adaptive.notebook_integration import ensure_holoviews
1413
from adaptive.utils import cache_latest, named_product, restore
1514

1615

17-
def dispatch(child_functions: List[Callable], arg: Any,) -> Union[int, float64, float]:
16+
def dispatch(
17+
child_functions: List[Callable], arg: Any,
18+
) -> Union[int, np.float64, float]:
1819
index, x = arg
1920
return child_functions[index](x)
2021

@@ -165,7 +166,7 @@ def _ask_and_tell_based_on_loss_improvements(self, n: int) -> Any:
165166
def _ask_and_tell_based_on_loss(
166167
self, n: int
167168
) -> Union[
168-
Tuple[List[Tuple[int, float]], List[float64]],
169+
Tuple[List[Tuple[int, float]], List[np.float64]],
169170
Tuple[List[Union[Tuple[int, int], Tuple[int, float]]], List[float]],
170171
Tuple[List[Tuple[int, int]], List[float]],
171172
]:
@@ -192,9 +193,9 @@ def _ask_and_tell_based_on_loss(
192193
def _ask_and_tell_based_on_npoints(
193194
self, n: int
194195
) -> Union[
195-
Tuple[List[Union[Tuple[int64, int], Tuple[int64, float]]], List[float]],
196-
Tuple[List[Tuple[int64, float]], List[float64]],
197-
Tuple[List[Tuple[int64, int]], List[float]],
196+
Tuple[List[Union[Tuple[np.int64, int], Tuple[np.int64, float]]], List[float]],
197+
Tuple[List[Tuple[np.int64, float]], List[np.float64]],
198+
Tuple[List[Tuple[np.int64, int]], List[float]],
198199
]:
199200
selected = [] # tuples ((learner_index, point), loss_improvement)
200201
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
@@ -214,7 +215,7 @@ def _ask_and_tell_based_on_npoints(
214215
def _ask_and_tell_based_on_cycle(
215216
self, n: int
216217
) -> Union[
217-
Tuple[List[Tuple[int, float]], List[float64]],
218+
Tuple[List[Tuple[int, float]], List[np.float64]],
218219
Tuple[List[Union[Tuple[int, int], Tuple[int, float]]], List[float]],
219220
Tuple[List[Tuple[int, int]], List[float]],
220221
]:
@@ -240,7 +241,9 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]
240241
return self._ask_and_tell(n)
241242

242243
def tell(
243-
self, x: Any, y: Union[int, float64, float, Tuple[int, int], Tuple[int64, int]]
244+
self,
245+
x: Any,
246+
y: Union[int, np.float64, float, Tuple[int, int], Tuple[np.int64, int]],
244247
) -> None:
245248
index, x = x
246249
self._ask_cache.pop(index, None)
@@ -256,7 +259,7 @@ def tell_pending(self, x: Any) -> None:
256259

257260
def _losses(
258261
self, real: bool = True
259-
) -> Union[List[float], List[float64], List[Union[float, float64]]]:
262+
) -> Union[List[float], List[np.float64], List[Union[float, np.float64]]]:
260263
losses = []
261264
loss_dict = self._loss if real else self._pending_loss
262265

@@ -268,7 +271,7 @@ def _losses(
268271
return losses
269272

270273
@cache_latest
271-
def loss(self, real: bool = True) -> Union[float64, float]:
274+
def loss(self, real: bool = True) -> Union[np.float64, float]:
272275
losses = self._losses(real)
273276
return max(losses)
274277

adaptive/learner/integrator_coeffs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
import scipy.linalg
9-
from numpy import ndarray
109

1110

1211
def legendre(n: int) -> List[List[Fraction]]:
@@ -30,7 +29,7 @@ def legendre(n: int) -> List[List[Fraction]]:
3029
return result
3130

3231

33-
def newton(n: int) -> ndarray:
32+
def newton(n: int) -> np.ndarray:
3433
"""Compute the monomial coefficients of the Newton polynomial over the
3534
nodes of the n-point Clenshaw-Curtis quadrature rule.
3635
"""
@@ -108,7 +107,7 @@ def scalar_product(a: List[Fraction], b: List[Fraction]) -> Fraction:
108107
return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
109108

110109

111-
def calc_bdef(ns: Tuple[int, int, int, int]) -> List[ndarray]:
110+
def calc_bdef(ns: Tuple[int, int, int, int]) -> List[np.ndarray]:
112111
"""Calculate the decompositions of Newton polynomials (over the nodes
113112
of the n-point Clenshaw-Curtis quadrature rule) in terms of
114113
Legandre polynomials.
@@ -134,7 +133,7 @@ def calc_bdef(ns: Tuple[int, int, int, int]) -> List[ndarray]:
134133
return result
135134

136135

137-
def calc_V(x: ndarray, n: int) -> ndarray:
136+
def calc_V(x: np.ndarray, n: int) -> np.ndarray:
138137
V = [np.ones(x.shape), x.copy()]
139138
for i in range(2, n):
140139
V.append((2 * i - 1) / i * x * V[-1] - (i - 1) / i * V[-2])

adaptive/learner/integrator_learner.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Callable, List, Optional, Set, Tuple, Union
99

1010
import numpy as np
11-
from numpy import bool_, float64, ndarray, ufunc
11+
from numpy import ufunc
1212
from scipy.linalg import norm
1313
from sortedcontainers import SortedSet
1414

@@ -33,7 +33,7 @@
3333
)
3434

3535

36-
def _downdate(c: ndarray, nans: List[int], depth: int) -> ndarray:
36+
def _downdate(c: np.ndarray, nans: List[int], depth: int) -> np.ndarray:
3737
# This is algorithm 5 from the thesis of Pedro Gonnet.
3838
b = b_def[depth].copy()
3939
m = ns[depth] - 1
@@ -51,7 +51,7 @@ def _downdate(c: ndarray, nans: List[int], depth: int) -> ndarray:
5151
return c
5252

5353

54-
def _zero_nans(fx: ndarray) -> List[int]:
54+
def _zero_nans(fx: np.ndarray) -> List[int]:
5555
"""Caution: this function modifies fx."""
5656
nans = []
5757
for i in range(len(fx)):
@@ -61,7 +61,7 @@ def _zero_nans(fx: ndarray) -> List[int]:
6161
return nans
6262

6363

64-
def _calc_coeffs(fx: ndarray, depth: int) -> ndarray:
64+
def _calc_coeffs(fx: np.ndarray, depth: int) -> np.ndarray:
6565
"""Caution: this function modifies fx."""
6666
nans = _zero_nans(fx)
6767
c_new = V_inv[depth] @ fx
@@ -142,7 +142,11 @@ class _Interval:
142142
]
143143

144144
def __init__(
145-
self, a: Union[int, float64], b: Union[int, float64], depth: int, rdepth: int
145+
self,
146+
a: Union[int, np.float64],
147+
b: Union[int, np.float64],
148+
depth: int,
149+
rdepth: int,
146150
) -> None:
147151
self.children = []
148152
self.data = {}
@@ -163,7 +167,7 @@ def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
163167
return ival
164168

165169
@property
166-
def T(self) -> ndarray:
170+
def T(self) -> np.ndarray:
167171
"""Get the correct shift matrix.
168172
169173
Should only be called on children of a split interval.
@@ -180,7 +184,7 @@ def refinement_complete(self, depth: int) -> bool:
180184
return False
181185
return all(p in self.data for p in self.points(depth))
182186

183-
def points(self, depth: Optional[int] = None) -> ndarray:
187+
def points(self, depth: Optional[int] = None) -> np.ndarray:
184188
if depth is None:
185189
depth = self.depth
186190
a = self.a
@@ -209,7 +213,7 @@ def split(self) -> List["_Interval"]:
209213
def calc_igral(self) -> None:
210214
self.igral = (self.b - self.a) * self.c[0] / sqrt(2)
211215

212-
def update_heuristic_err(self, value: Union[float64, float]) -> None:
216+
def update_heuristic_err(self, value: Union[np.float64, float]) -> None:
213217
"""Sets the error of an interval using a heuristic (half the error of
214218
the parent) when the actual error cannot be calculated due to its
215219
parents not being finished yet. This error is propagated down to its
@@ -222,7 +226,7 @@ def update_heuristic_err(self, value: Union[float64, float]) -> None:
222226
continue
223227
child.update_heuristic_err(value / 2)
224228

225-
def calc_err(self, c_old: ndarray) -> float:
229+
def calc_err(self, c_old: np.ndarray) -> float:
226230
c_new = self.c
227231
c_diff = np.zeros(max(len(c_old), len(c_new)))
228232
c_diff[: len(c_old)] = c_old
@@ -255,7 +259,7 @@ def update_ndiv_recursively(self) -> None:
255259

256260
def complete_process(
257261
self, depth: int
258-
) -> Union[Tuple[bool, bool], Tuple[bool, bool_]]:
262+
) -> Union[Tuple[bool, bool], Tuple[bool, np.bool_]]:
259263
"""Calculate the integral contribution and error from this interval,
260264
and update the done leaves of all ancestor intervals."""
261265
assert self.depth_complete is None or self.depth_complete == depth - 1
@@ -399,7 +403,7 @@ def __init__(
399403
def approximating_intervals(self) -> Set["_Interval"]:
400404
return self.first_ival.done_leaves
401405

402-
def tell(self, point: float64, value: float64) -> None:
406+
def tell(self, point: np.float64, value: np.float64) -> None:
403407
if point not in self.x_mapping:
404408
raise ValueError(f"Point {point} doesn't belong to any interval")
405409
self.data[point] = value
@@ -458,7 +462,9 @@ def add_ival(self, ival: "_Interval") -> None:
458462

459463
def ask(
460464
self, n: int, tell_pending: bool = True
461-
) -> Union[Tuple[List[float64], List[float64]], Tuple[List[float64], List[float]]]:
465+
) -> Union[
466+
Tuple[List[np.float64], List[np.float64]], Tuple[List[np.float64], List[float]]
467+
]:
462468
"""Choose points for learners."""
463469
if not tell_pending:
464470
with restore(self):
@@ -468,7 +474,9 @@ def ask(
468474

469475
def _ask_and_tell_pending(
470476
self, n: int
471-
) -> Union[Tuple[List[float64], List[float64]], Tuple[List[float64], List[float]]]:
477+
) -> Union[
478+
Tuple[List[np.float64], List[np.float64]], Tuple[List[np.float64], List[float]]
479+
]:
472480
points, loss_improvements = self.pop_from_stack(n)
473481
n_left = n - len(points)
474482
while n_left > 0:
@@ -487,9 +495,9 @@ def _ask_and_tell_pending(
487495
def pop_from_stack(
488496
self, n: int
489497
) -> Union[
490-
Tuple[List[float64], List[float64]],
498+
Tuple[List[np.float64], List[np.float64]],
491499
Tuple[List[Any], List[Any]],
492-
Tuple[List[float64], List[float]],
500+
Tuple[List[np.float64], List[float]],
493501
]:
494502
points = self._stack[:n]
495503
self._stack = self._stack[n:]
@@ -501,7 +509,7 @@ def pop_from_stack(
501509
def remove_unfinished(self):
502510
pass
503511

504-
def _fill_stack(self) -> List[float64]:
512+
def _fill_stack(self) -> List[np.float64]:
505513
# XXX: to-do if all the ivals have err=inf, take the interval
506514
# with the lowest rdepth and no children.
507515
force_split = bool(self.priority_split)
@@ -542,11 +550,11 @@ def npoints(self) -> int:
542550
return len(self.data)
543551

544552
@property
545-
def igral(self) -> float64:
553+
def igral(self) -> np.float64:
546554
return sum(i.igral for i in self.approximating_intervals)
547555

548556
@property
549-
def err(self) -> float64:
557+
def err(self) -> np.float64:
550558
if self.approximating_intervals:
551559
err = sum(i.err for i in self.approximating_intervals)
552560
if err > sys.float_info.max:

0 commit comments

Comments
 (0)