Skip to content

Commit 7c34051

Browse files
committed
make methods private in the LearnerND, closes #85
1 parent 190248c commit 7c34051

File tree

1 file changed

+15
-28
lines changed

1 file changed

+15
-28
lines changed

adaptive/learner/learnerND.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,10 @@ def vdim(self):
227227
def bounds_are_done(self):
228228
return all(p in self.data for p in self._bounds_points)
229229

230-
def ip(self):
230+
def _ip(self):
231231
"""A `scipy.interpolate.LinearNDInterpolator` instance
232232
containing the learner's data."""
233-
# XXX: take our own triangulation into account when generating the ip
233+
# XXX: take our own triangulation into account when generating the _ip
234234
return interpolate.LinearNDInterpolator(self.points, self.values)
235235

236236
@property
@@ -242,7 +242,7 @@ def tri(self):
242242

243243
try:
244244
self._tri = Triangulation(self.points)
245-
self.update_losses(set(), self._tri.simplices)
245+
self._update_losses(set(), self._tri.simplices)
246246
return self._tri
247247
except ValueError:
248248
# A ValueError is raised if we do not have enough points or
@@ -283,7 +283,7 @@ def tell(self, point, value):
283283
simplex = None
284284
to_delete, to_add = tri.add_point(
285285
point, simplex, transform=self._transform)
286-
self.update_losses(to_delete, to_add)
286+
self._update_losses(to_delete, to_add)
287287

288288
def _simplex_exists(self, simplex):
289289
simplex = tuple(sorted(simplex))
@@ -441,7 +441,7 @@ def _ask(self):
441441

442442
return self._ask_best_point() # O(log N)
443443

444-
def update_losses(self, to_delete: set, to_add: set):
444+
def _update_losses(self, to_delete: set, to_add: set):
445445
# XXX: add the points outside the triangulation to this as well
446446
pending_points_unbound = set()
447447

@@ -455,7 +455,7 @@ def update_losses(self, to_delete: set, to_add: set):
455455
if p not in self.data)
456456

457457
for simplex in to_add:
458-
loss = self.compute_loss(simplex)
458+
loss = self._compute_loss(simplex)
459459
self._losses[simplex] = loss
460460

461461
for p in pending_points_unbound:
@@ -469,7 +469,7 @@ def update_losses(self, to_delete: set, to_add: set):
469469
self._update_subsimplex_losses(
470470
simplex, self._subtriangulations[simplex].simplices)
471471

472-
def compute_loss(self, simplex):
472+
def _compute_loss(self, simplex):
473473
# get the loss
474474
vertices = self.tri.get_vertices(simplex)
475475
values = [self.data[tuple(v)] for v in vertices]
@@ -481,7 +481,7 @@ def compute_loss(self, simplex):
481481
# compute the loss on the scaled simplex
482482
return float(self.loss_per_simplex(vertices, values))
483483

484-
def recompute_all_losses(self):
484+
def _recompute_all_losses(self):
485485
"""Recompute all losses and pending losses."""
486486
# amortized O(N) complexity
487487
if self.tri is None:
@@ -492,7 +492,7 @@ def recompute_all_losses(self):
492492

493493
# recompute all losses
494494
for simplex in self.tri.simplices:
495-
loss = self.compute_loss(simplex)
495+
loss = self._compute_loss(simplex)
496496
self._losses[simplex] = loss
497497

498498
# now distribute it around the the children if they are present
@@ -543,27 +543,14 @@ def _update_range(self, new_output):
543543
scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
544544
if scale_factor > self._recompute_losses_factor:
545545
self._old_scale = self._scale
546-
self.recompute_all_losses()
546+
self._recompute_all_losses()
547547
return True
548548
return False
549549

550-
def losses(self):
551-
"""Get the losses of each simplex in the current triangulation, as dict
552-
553-
Returns
554-
-------
555-
losses : dict
556-
the key is a simplex, the value is the loss of this simplex
557-
"""
558-
# XXX could be a property
559-
if self.tri is None:
560-
return dict()
561-
562-
return self._losses
563-
564550
@cache_latest
565551
def loss(self, real=True):
566-
losses = self.losses() # XXX: compute pending loss if real == False
552+
# XXX: compute pending loss if real == False
553+
losses = self._losses if self.tri is not None else dict()
567554
return max(losses.values()) if losses else float('inf')
568555

569556
def remove_unfinished(self):
@@ -607,7 +594,7 @@ def plot(self, n=None, tri_alpha=0):
607594
xs = ys = np.linspace(0, 1, n)
608595
xs = xs * (x[1] - x[0]) + x[0]
609596
ys = ys * (y[1] - y[0]) + y[0]
610-
z = self.ip()(xs[:, None], ys[None, :]).squeeze()
597+
z = self._ip()(xs[:, None], ys[None, :]).squeeze()
611598

612599
im = hv.Image(np.rot90(z), bounds=lbrt)
613600

@@ -656,7 +643,7 @@ def plot_slice(self, cut_mapping, n=None):
656643
for i in range(self.ndim)]
657644
ind = next(i for i in range(self.ndim) if i not in cut_mapping)
658645
x = values[ind]
659-
y = self.ip()(*values)
646+
y = self._ip()(*values)
660647
p = hv.Path((x, y))
661648

662649
# Plot with 5% margins such that the boundary points are visible
@@ -686,7 +673,7 @@ def plot_slice(self, cut_mapping, n=None):
686673
lbrt = np.reshape(lbrt, (2, 2)).T.flatten().tolist()
687674

688675
if len(self.data) >= 4:
689-
z = self.ip()(*values).squeeze()
676+
z = self._ip()(*values).squeeze()
690677
im = hv.Image(np.rot90(z), bounds=lbrt)
691678
else:
692679
im = hv.Image([], bounds=lbrt)

0 commit comments

Comments
 (0)