Skip to content

Commit ede9582

Browse files
authored
Merge pull request #222 from python-adaptive/attr_checking_base_class
add _RequireAttrsABCMeta and make the BaseLearner use it
2 parents 187f88f + 1474a5d commit ede9582

File tree

6 files changed

+66
-29
lines changed

6 files changed

+66
-29
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,24 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9090

9191
self.strategy = strategy
9292

93+
@property
94+
def data(self):
95+
data = {}
96+
for i, l in enumerate(self.learners):
97+
data.update({(i, p): v for p, v in l.data.items()})
98+
return data
99+
100+
@property
101+
def pending_points(self):
102+
pending_points = set()
103+
for i, l in enumerate(self.learners):
104+
pending_points.update({(i, p) for p in l.pending_points})
105+
return pending_points
106+
107+
@property
108+
def npoints(self):
109+
return sum(l.npoints for l in self.learners)
110+
93111
@property
94112
def strategy(self):
95113
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or

adaptive/learner/base_learner.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from contextlib import suppress
55
from copy import deepcopy
66

7-
from adaptive.utils import load, save
7+
from adaptive.utils import load, save, _RequireAttrsABCMeta
88

99

1010
def uses_nth_neighbors(n):
@@ -61,30 +61,31 @@ def _wrapped(loss_per_interval):
6161
return _wrapped
6262

6363

64-
class BaseLearner(metaclass=abc.ABCMeta):
64+
class BaseLearner(metaclass=_RequireAttrsABCMeta):
6565
"""Base class for algorithms for learning a function 'f: X → Y'.
6666
6767
Attributes
6868
----------
6969
function : callable: X → Y
70-
The function to learn.
70+
The function to learn. A subclass of BaseLearner might modify
71+
the user's supplied function.
7172
data : dict: X → Y
7273
`function` evaluated at certain points.
73-
The values can be 'None', which indicates that the point
74-
will be evaluated, but that we do not have the result yet.
75-
npoints : int, optional
76-
The number of evaluated points that have been added to the learner.
77-
Subclasses do not *have* to implement this attribute.
78-
pending_points : set, optional
74+
pending_points : set
7975
Points that have been requested but have not been evaluated yet.
80-
Subclasses do not *have* to implement this attribute.
76+
npoints : int
77+
The number of evaluated points that have been added to the learner.
8178
8279
Notes
8380
-----
8481
Subclasses may define a ``plot`` method that takes no parameters
8582
and returns a holoviews plot.
8683
"""
8784

85+
data: dict
86+
npoints: int
87+
pending_points: set
88+
8889
def tell(self, x, y):
8990
"""Tell the learner about a single value.
9091

adaptive/learner/integrator_learner.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class _Interval:
100100
The parent interval.
101101
children : list of `_Interval`s
102102
The intervals resulting from a split.
103-
done_points : dict
103+
data : dict
104104
A dictionary with the x-values and y-values: `{x1: y1, x2: y2 ...}`.
105105
done : bool
106106
The integral and the error for the interval has been calculated.
@@ -133,15 +133,15 @@ class _Interval:
133133
"ndiv",
134134
"parent",
135135
"children",
136-
"done_points",
136+
"data",
137137
"done_leaves",
138138
"depth_complete",
139139
"removed",
140140
]
141141

142142
def __init__(self, a, b, depth, rdepth):
143143
self.children = []
144-
self.done_points = {}
144+
self.data = {}
145145
self.a = a
146146
self.b = b
147147
self.depth = depth
@@ -172,9 +172,9 @@ def T(self):
172172

173173
def refinement_complete(self, depth):
174174
"""The interval has all the y-values to calculate the intergral."""
175-
if len(self.done_points) < ns[depth]:
175+
if len(self.data) < ns[depth]:
176176
return False
177-
return all(p in self.done_points for p in self.points(depth))
177+
return all(p in self.data for p in self.points(depth))
178178

179179
def points(self, depth=None):
180180
if depth is None:
@@ -255,7 +255,7 @@ def complete_process(self, depth):
255255
assert self.depth_complete is None or self.depth_complete == depth - 1
256256
self.depth_complete = depth
257257

258-
fx = [self.done_points[k] for k in self.points(depth)]
258+
fx = [self.data[k] for k in self.points(depth)]
259259
self.fx = np.array(fx)
260260
force_split = False # This may change when refining
261261

@@ -375,7 +375,7 @@ def __init__(self, function, bounds, tol):
375375
self.tol = tol
376376
self.max_ivals = 1000
377377
self.priority_split = []
378-
self.done_points = {}
378+
self.data = {}
379379
self.pending_points = set()
380380
self._stack = []
381381
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
@@ -391,13 +391,13 @@ def approximating_intervals(self):
391391
def tell(self, point, value):
392392
if point not in self.x_mapping:
393393
raise ValueError(f"Point {point} doesn't belong to any interval")
394-
self.done_points[point] = value
394+
self.data[point] = value
395395
self.pending_points.discard(point)
396396

397397
# Select the intervals that have this point
398398
ivals = self.x_mapping[point]
399399
for ival in ivals:
400-
ival.done_points[point] = value
400+
ival.data[point] = value
401401

402402
if ival.depth_complete is None:
403403
from_depth = 0 if ival.parent is not None else 2
@@ -438,8 +438,8 @@ def add_ival(self, ival):
438438
for x in ival.points():
439439
# Update the mappings
440440
self.x_mapping[x].add(ival)
441-
if x in self.done_points:
442-
self.tell(x, self.done_points[x])
441+
if x in self.data:
442+
self.tell(x, self.data[x])
443443
elif x not in self.pending_points:
444444
self.pending_points.add(x)
445445
self._stack.append(x)
@@ -518,7 +518,7 @@ def _fill_stack(self):
518518
@property
519519
def npoints(self):
520520
"""Number of evaluated points."""
521-
return len(self.done_points)
521+
return len(self.data)
522522

523523
@property
524524
def igral(self):
@@ -552,11 +552,9 @@ def loss(self, real=True):
552552
def plot(self):
553553
hv = ensure_holoviews()
554554
ivals = sorted(self.ivals, key=attrgetter("a"))
555-
if not self.done_points:
555+
if not self.data:
556556
return hv.Path([])
557-
xs, ys = zip(
558-
*[(x, y) for ival in ivals for x, y in sorted(ival.done_points.items())]
559-
)
557+
xs, ys = zip(*[(x, y) for ival in ivals for x, y in sorted(ival.data.items())])
560558
return hv.Path((xs, ys))
561559

562560
def _get_data(self):
@@ -565,7 +563,7 @@ def _get_data(self):
565563

566564
return (
567565
self.priority_split,
568-
self.done_points,
566+
self.data,
569567
self.pending_points,
570568
self._stack,
571569
x_mapping,
@@ -574,7 +572,7 @@ def _get_data(self):
574572
)
575573

576574
def _set_data(self, data):
577-
self.priority_split, self.done_points, self.pending_points, self._stack, x_mapping, self.ivals, self.first_ival = (
575+
self.priority_split, self.data, self.pending_points, self._stack, x_mapping, self.ivals, self.first_ival = (
578576
data
579577
)
580578

adaptive/learner/skopt_learner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ class SKOptLearner(Optimizer, BaseLearner):
2626
def __init__(self, function, **kwargs):
2727
self.function = function
2828
self.pending_points = set()
29+
self.data = {}
2930
super().__init__(**kwargs)
3031

3132
def tell(self, x, y, fit=True):
3233
self.pending_points.discard(x)
34+
self.data[x] = y
3335
super().tell([x], y, fit)
3436

3537
def tell_pending(self, x):

adaptive/tests/test_cquad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def test_tell_in_random_order(first_add_33=False):
188188
learners.append(learner)
189189

190190
# Check whether the points of the learners are identical
191-
assert set(learners[0].done_points) == set(learners[1].done_points)
191+
assert set(learners[0].data) == set(learners[1].data)
192192

193193
# Test whether approximating_intervals gives a complete set of intervals
194194
for learner in learners:

adaptive/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3+
import abc
34
import functools
45
import gzip
56
import os
@@ -67,3 +68,20 @@ def decorator(method):
6768
return functools.wraps(other)(method)
6869

6970
return decorator
71+
72+
73+
class _RequireAttrsABCMeta(abc.ABCMeta):
74+
def __call__(self, *args, **kwargs):
75+
obj = super().__call__(*args, **kwargs)
76+
for name, type_ in obj.__annotations__.items():
77+
try:
78+
x = getattr(obj, name)
79+
except AttributeError:
80+
raise AttributeError(
81+
f"Required attribute {name} not set in __init__."
82+
) from None
83+
else:
84+
if not isinstance(x, type_):
85+
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
86+
raise TypeError(msg)
87+
return obj

0 commit comments

Comments
 (0)