Skip to content

Commit c35bea3

Browse files
committed
add 'from __future__ import annotations'
1 parent 3aa7e84 commit c35bea3

14 files changed

+207
-223
lines changed

adaptive/learner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from contextlib import suppress
24

35
from adaptive.learner.average_learner import AverageLearner

adaptive/learner/average_learner.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
from math import sqrt
2-
from typing import Callable, Dict, List, Optional, Tuple
4+
from typing import Callable
35

46
import cloudpickle
57
import numpy as np
@@ -38,8 +40,8 @@ class AverageLearner(BaseLearner):
3840
def __init__(
3941
self,
4042
function: Callable[[int], Real],
41-
atol: Optional[float] = None,
42-
rtol: Optional[float] = None,
43+
atol: float | None = None,
44+
rtol: float | None = None,
4345
min_npoints: int = 2,
4446
) -> None:
4547
if atol is None and rtol is None:
@@ -68,7 +70,7 @@ def to_numpy(self):
6870
"""Data as NumPy array of size (npoints, 2) with seeds and values."""
6971
return np.array(sorted(self.data.items()))
7072

71-
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[Float]]:
73+
def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]]:
7274
points = list(range(self.n_requested, self.n_requested + n))
7375

7476
if any(p in self.data or p in self.pending_points for p in points):
@@ -159,10 +161,10 @@ def plot(self):
159161
vals = hv.Points(vals)
160162
return hv.operation.histogram(vals, num_bins=num_bins, dimension="y")
161163

162-
def _get_data(self) -> Tuple[Dict[int, Real], int, Real, Real]:
164+
def _get_data(self) -> tuple[dict[int, Real], int, Real, Real]:
163165
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
164166

165-
def _set_data(self, data: Tuple[Dict[int, Real], int, Real, Real]) -> None:
167+
def _set_data(self, data: tuple[dict[int, Real], int, Real, Real]) -> None:
166168
self.data, self.npoints, self.sum_f, self.sum_f_sq = data
167169

168170
def __getstate__(self):

adaptive/learner/average_learner1D.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
2+
13
import math
24
import sys
35
from collections import defaultdict
46
from copy import deepcopy
57
from math import hypot
6-
from typing import Callable, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple
8+
from typing import Callable, DefaultDict, List, Sequence, Tuple
79

810
import numpy as np
911
import scipy.stats
@@ -17,7 +19,7 @@
1719
Point = Tuple[int, Real]
1820
Points = List[Point]
1921

20-
__all__: List[str] = ["AverageLearner1D"]
22+
__all__: list[str] = ["AverageLearner1D"]
2123

2224

2325
class AverageLearner1D(Learner1D):
@@ -65,11 +67,10 @@ class AverageLearner1D(Learner1D):
6567

6668
def __init__(
6769
self,
68-
function: Callable[[Tuple[int, Real]], Real],
69-
bounds: Tuple[Real, Real],
70-
loss_per_interval: Optional[
71-
Callable[[Sequence[Real], Sequence[Real]], float]
72-
] = None,
70+
function: Callable[[tuple[int, Real]], Real],
71+
bounds: tuple[Real, Real],
72+
loss_per_interval: None
73+
| (Callable[[Sequence[Real], Sequence[Real]], float]) = None,
7374
delta: float = 0.2,
7475
alpha: float = 0.005,
7576
neighbor_sampling: float = 0.3,
@@ -105,7 +106,7 @@ def __init__(
105106
self._number_samples = SortedDict()
106107
# This set contains the points x that have less than min_samples
107108
# samples or less than a (neighbor_sampling*100)% of their neighbors
108-
self._undersampled_points: Set[Real] = set()
109+
self._undersampled_points: set[Real] = set()
109110
# Contains the error in the estimate of the
110111
# mean at each point x in the form {x0: error(x0), ...}
111112
self.error: ItemSortedDict[Real, float] = decreasing_dict()
@@ -126,7 +127,7 @@ def min_samples_per_point(self) -> int:
126127
return 0
127128
return min(self._number_samples.values())
128129

129-
def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
130+
def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:
130131
"""Return 'n' points that are expected to maximally reduce the loss."""
131132
# If some point is undersampled, resample it
132133
if len(self._undersampled_points):
@@ -155,7 +156,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
155156

156157
return points, loss_improvements
157158

158-
def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
159+
def _ask_for_more_samples(self, x: Real, n: int) -> tuple[Points, list[float]]:
159160
"""When asking for n points, the learner returns n times an existing point
160161
to be resampled, since in general n << min_samples and this point will
161162
need to be resampled many more times"""
@@ -174,7 +175,7 @@ def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
174175
loss_improvements = [loss_improvement / n] * n
175176
return points, loss_improvements
176177

177-
def _ask_for_new_point(self, n: int) -> Tuple[Points, List[float]]:
178+
def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
178179
"""When asking for n new points, the learner returns n times a single
179180
new point, since in general n << min_samples and this point will need
180181
to be resampled many more times"""
@@ -388,7 +389,7 @@ def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
388389
# simultaneously, before we move on to a new x
389390
self.tell_many_at_point(x, seed_y_mapping)
390391

391-
def tell_many_at_point(self, x: Real, seed_y_mapping: Dict[int, Real]) -> None:
392+
def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
392393
"""Tell the learner about many samples at a certain location x.
393394
394395
Parameters

adaptive/learner/balancing_learner.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
1+
from __future__ import annotations
2+
13
import itertools
24
import numbers
35
from collections import defaultdict
46
from collections.abc import Iterable
57
from contextlib import suppress
68
from functools import partial
79
from operator import itemgetter
8-
from typing import (
9-
Any,
10-
Callable,
11-
Dict,
12-
List,
13-
Literal,
14-
Optional,
15-
Sequence,
16-
Set,
17-
Tuple,
18-
Union,
19-
)
10+
from typing import Any, Callable, Dict, Literal, Sequence, Tuple, Union
2011

2112
import numpy as np
2213

@@ -25,7 +16,7 @@
2516
from adaptive.utils import cache_latest, named_product, restore
2617

2718

28-
def dispatch(child_functions: List[Callable], arg: Any) -> Union[Any]:
19+
def dispatch(child_functions: list[Callable], arg: Any) -> Any:
2920
index, x = arg
3021
return child_functions[index](x)
3122

@@ -91,9 +82,9 @@ class BalancingLearner(BaseLearner):
9182

9283
def __init__(
9384
self,
94-
learners: List[BaseLearner],
85+
learners: list[BaseLearner],
9586
*,
96-
cdims: Optional[CDIMS_TYPE] = None,
87+
cdims: CDIMS_TYPE | None = None,
9788
strategy: STRATEGY_TYPE = "loss_improvements",
9889
) -> None:
9990
self.learners = learners
@@ -116,14 +107,14 @@ def __init__(
116107
self.strategy: STRATEGY_TYPE = strategy
117108

118109
@property
119-
def data(self) -> Dict[Tuple[int, Any], Any]:
110+
def data(self) -> dict[tuple[int, Any], Any]:
120111
data = {}
121112
for i, l in enumerate(self.learners):
122113
data.update({(i, p): v for p, v in l.data.items()})
123114
return data
124115

125116
@property
126-
def pending_points(self) -> Set[Tuple[int, Any]]:
117+
def pending_points(self) -> set[tuple[int, Any]]:
127118
pending_points = set()
128119
for i, l in enumerate(self.learners):
129120
pending_points.update({(i, p) for p in l.pending_points})
@@ -173,7 +164,7 @@ def strategy(self, strategy: STRATEGY_TYPE) -> None:
173164

174165
def _ask_and_tell_based_on_loss_improvements(
175166
self, n: int
176-
) -> Tuple[List[Tuple[int, Any]], List[float]]:
167+
) -> tuple[list[tuple[int, Any]], list[float]]:
177168
selected = [] # tuples ((learner_index, point), loss_improvement)
178169
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
179170
for _ in range(n):
@@ -198,7 +189,7 @@ def _ask_and_tell_based_on_loss_improvements(
198189

199190
def _ask_and_tell_based_on_loss(
200191
self, n: int
201-
) -> Tuple[List[Tuple[int, Any]], List[float]]:
192+
) -> tuple[list[tuple[int, Any]], list[float]]:
202193
selected = [] # tuples ((learner_index, point), loss_improvement)
203194
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
204195
for _ in range(n):
@@ -221,7 +212,7 @@ def _ask_and_tell_based_on_loss(
221212

222213
def _ask_and_tell_based_on_npoints(
223214
self, n: numbers.Integral
224-
) -> Tuple[List[Tuple[numbers.Integral, Any]], List[float]]:
215+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
225216
selected = [] # tuples ((learner_index, point), loss_improvement)
226217
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
227218
for _ in range(n):
@@ -239,7 +230,7 @@ def _ask_and_tell_based_on_npoints(
239230

240231
def _ask_and_tell_based_on_cycle(
241232
self, n: int
242-
) -> Tuple[List[Tuple[numbers.Integral, Any]], List[float]]:
233+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
243234
points, loss_improvements = [], []
244235
for _ in range(n):
245236
index = next(self._cycle)
@@ -252,7 +243,7 @@ def _ask_and_tell_based_on_cycle(
252243

253244
def ask(
254245
self, n: int, tell_pending: bool = True
255-
) -> Tuple[List[Tuple[numbers.Integral, Any]], List[float]]:
246+
) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]:
256247
"""Chose points for learners."""
257248
if n == 0:
258249
return [], []
@@ -263,20 +254,20 @@ def ask(
263254
else:
264255
return self._ask_and_tell(n)
265256

266-
def tell(self, x: Tuple[numbers.Integral, Any], y: Any) -> None:
257+
def tell(self, x: tuple[numbers.Integral, Any], y: Any) -> None:
267258
index, x = x
268259
self._ask_cache.pop(index, None)
269260
self._loss.pop(index, None)
270261
self._pending_loss.pop(index, None)
271262
self.learners[index].tell(x, y)
272263

273-
def tell_pending(self, x: Tuple[numbers.Integral, Any]) -> None:
264+
def tell_pending(self, x: tuple[numbers.Integral, Any]) -> None:
274265
index, x = x
275266
self._ask_cache.pop(index, None)
276267
self._loss.pop(index, None)
277268
self.learners[index].tell_pending(x)
278269

279-
def _losses(self, real: bool = True) -> List[float]:
270+
def _losses(self, real: bool = True) -> list[float]:
280271
losses = []
281272
loss_dict = self._loss if real else self._pending_loss
282273

@@ -294,8 +285,8 @@ def loss(self, real: bool = True) -> float:
294285

295286
def plot(
296287
self,
297-
cdims: Optional[CDIMS_TYPE] = None,
298-
plotter: Optional[Callable[[BaseLearner], Any]] = None,
288+
cdims: CDIMS_TYPE | None = None,
289+
plotter: Callable[[BaseLearner], Any] | None = None,
299290
dynamic: bool = True,
300291
):
301292
"""Returns a DynamicMap with sliders.
@@ -380,9 +371,9 @@ def from_product(
380371
cls,
381372
f,
382373
learner_type: BaseLearner,
383-
learner_kwargs: Dict[str, Any],
384-
combos: Dict[str, Sequence[Any]],
385-
) -> "BalancingLearner":
374+
learner_kwargs: dict[str, Any],
375+
combos: dict[str, Sequence[Any]],
376+
) -> BalancingLearner:
386377
"""Create a `BalancingLearner` with learners of all combinations of
387378
named variables’ values. The `cdims` will be set correctly, so calling
388379
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -431,7 +422,7 @@ def from_product(
431422

432423
def save(
433424
self,
434-
fname: Union[Callable[[BaseLearner], str], Sequence[str]],
425+
fname: Callable[[BaseLearner], str] | Sequence[str],
435426
compress: bool = True,
436427
) -> None:
437428
"""Save the data of the child learners into pickle files
@@ -473,7 +464,7 @@ def save(
473464

474465
def load(
475466
self,
476-
fname: Union[Callable[[BaseLearner], str], Sequence[str]],
467+
fname: Callable[[BaseLearner], str] | Sequence[str],
477468
compress: bool = True,
478469
) -> None:
479470
"""Load the data of the child learners from pickle files
@@ -499,20 +490,20 @@ def load(
499490
for l in self.learners:
500491
l.load(fname(l), compress=compress)
501492

502-
def _get_data(self) -> List[Any]:
493+
def _get_data(self) -> list[Any]:
503494
return [l._get_data() for l in self.learners]
504495

505-
def _set_data(self, data: List[Any]):
496+
def _set_data(self, data: list[Any]):
506497
for l, _data in zip(self.learners, data):
507498
l._set_data(_data)
508499

509-
def __getstate__(self) -> Tuple[List[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
500+
def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
510501
return (
511502
self.learners,
512503
self._cdims_default,
513504
self.strategy,
514505
)
515506

516-
def __setstate__(self, state: Tuple[List[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
507+
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
517508
learners, cdims, strategy = state
518509
self.__init__(learners, cdims=cdims, strategy=strategy)

adaptive/learner/base_learner.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import abc
24
from contextlib import suppress
35
from copy import deepcopy
4-
from typing import Any, Callable, Dict
6+
from typing import Any, Callable
57

68
from adaptive.utils import _RequireAttrsABCMeta, load, save
79

@@ -115,12 +117,10 @@ def tell_many(self, xs: Any, ys: Any) -> None:
115117
def tell_pending(self, x):
116118
"""Tell the learner that 'x' has been requested such
117119
that it's not suggested again."""
118-
pass
119120

120121
@abc.abstractmethod
121122
def remove_unfinished(self):
122123
"""Remove uncomputed data from the learner."""
123-
pass
124124

125125
@abc.abstractmethod
126126
def loss(self, real: bool = True):
@@ -147,7 +147,6 @@ def ask(self, n: int, tell_pending: bool = True):
147147
`pending_points`. Set this to False if you do not
148148
want to modify the state of the learner.
149149
"""
150-
pass
151150

152151
@abc.abstractmethod
153152
def _get_data(self):
@@ -196,8 +195,8 @@ def load(self, fname: str, compress: bool = True) -> None:
196195
data = load(fname, compress)
197196
self._set_data(data)
198197

199-
def __getstate__(self) -> Dict[str, Any]:
198+
def __getstate__(self) -> dict[str, Any]:
200199
return deepcopy(self.__dict__)
201200

202-
def __setstate__(self, state: Dict[str, Any]) -> None:
201+
def __setstate__(self, state: dict[str, Any]) -> None:
203202
self.__dict__ = state

0 commit comments

Comments
 (0)