Skip to content

Commit 4c10cc3

Browse files
authored
Merge pull request #437 from aai-institute/feature/filter-converged
Stop updating indices as soon as they converge in semivalue computations
2 parents 60d8aef + 491dbd1 commit 4c10cc3

File tree

7 files changed

+175
-36
lines changed

7 files changed

+175
-36
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
- No longer using docker within tests to start a memcached server
66
[PR #444](https://github.com/aai-institute/pyDVL/pull/444)
7+
- Faster semi-value computation with per-index check of stopping criteria (optional)
8+
[PR #437](https://github.com/aai-institute/pyDVL/pull/437)
79
- Improvements and fixes to notebooks
810
[PR #436](https://github.com/aai-institute/pyDVL/pull/436)
911
- Fix initialization of `data_names` in `ValuationResult.zeros()`

src/pydvl/reporting/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def plot_ci_array(
104104
means = np.mean(data, axis=0)
105105
variances = np.var(data, axis=0, ddof=1)
106106

107-
dummy: ValuationResult[np.int_, np.object_] = ValuationResult(
107+
dummy = ValuationResult[np.int_, np.object_](
108108
algorithm="dummy",
109109
values=means,
110110
variances=variances,

src/pydvl/value/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __len__(self) -> int:
200200
return len(self._outer_indices)
201201

202202
def __str__(self):
203-
return f"{self.__class__.__name__}"
203+
return self.__class__.__name__
204204

205205
def __repr__(self):
206206
return f"{self.__class__.__name__}({self._indices}, {self._outer_indices})"

src/pydvl/value/semivalues.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
import warnings
9393
from enum import Enum
9494
from itertools import islice
95-
from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
95+
from typing import Iterable, List, Optional, Protocol, Tuple, Type, cast
9696

9797
import scipy as sp
9898
from deprecate import deprecated
@@ -143,7 +143,7 @@ def __call__(self, n: int, k: int) -> float:
143143

144144

145145
def _marginal(
146-
u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT]
146+
u: Utility, coefficient: SVCoefficient, samples: Iterable[SampleT]
147147
) -> Tuple[MarginalT, ...]:
148148
"""Computation of marginal utility. This is a helper function for
149149
[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues].
@@ -186,6 +186,7 @@ def compute_generic_semivalues(
186186
done: StoppingCriterion,
187187
*,
188188
batch_size: int = 1,
189+
skip_converged: bool = False,
189190
n_jobs: int = 1,
190191
config: ParallelConfig = ParallelConfig(),
191192
progress: bool = False,
@@ -198,6 +199,15 @@ def compute_generic_semivalues(
198199
coefficient: The semi-value coefficient
199200
done: Stopping criterion.
200201
batch_size: Number of marginal evaluations per single parallel job.
202+
skip_converged: Whether to skip marginal evaluations for indices that
203+
have already converged. **CAUTION**: This is only entirely safe if
204+
the stopping criterion is [MaxUpdates][pydvl.value.stopping.MaxUpdates].
205+
For any other stopping criterion, the convergence status of indices
206+
may change during the computation, or they may be marked as having
207+
converged even though in fact the estimated values are far from the
208+
true values (e.g. for
209+
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError],
210+
you will probably have to carefully adjust the threshold).
201211
n_jobs: Number of parallel jobs to use.
202212
config: Object configuring parallel computation, with cluster
203213
address, number of cpus, etc.
@@ -262,16 +272,33 @@ def compute_generic_semivalues(
262272

263273
# Ensure that we always have n_submitted_jobs running
264274
try:
265-
for _ in range(n_submitted_jobs - len(pending)):
275+
while len(pending) < n_submitted_jobs:
266276
samples = tuple(islice(sampler_it, batch_size))
267277
if len(samples) == 0:
268278
raise StopIteration
269279

270-
pending.add(
271-
executor.submit(
272-
_marginal, u=u, coefficient=correction, samples=samples
280+
# Filter out samples for indices that have already converged
281+
filtered_samples = samples
282+
if skip_converged and len(done.converged) > 0:
283+
# cloudpickle can't pickle this on python 3.8:
284+
# filtered_samples = filter(
285+
# lambda t: not done.converged[t[0]], samples
286+
# )
287+
filtered_samples = tuple(
288+
(idx, sample)
289+
for idx, sample in samples
290+
if not done.converged[idx]
291+
)
292+
293+
if filtered_samples:
294+
pending.add(
295+
executor.submit(
296+
_marginal,
297+
u=u,
298+
coefficient=correction,
299+
samples=filtered_samples,
300+
)
273301
)
274-
)
275302
except StopIteration:
276303
if len(pending) == 0:
277304
return result

src/pydvl/value/stopping.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,94 @@
1-
"""
1+
r"""
22
Stopping criteria for value computations.
33
4-
This module provides a basic set of stopping criteria, like [MaxUpdates][pydvl.value.stopping.MaxUpdates],
5-
[MaxTime][pydvl.value.stopping.MaxTime], or [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others.
6-
These can behave in different ways depending on the context.
7-
For example, [MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
4+
This module provides a basic set of stopping criteria, like
5+
[MaxUpdates][pydvl.value.stopping.MaxUpdates],
6+
[MaxTime][pydvl.value.stopping.MaxTime], or
7+
[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others. These
8+
can behave in different ways depending on the context. For example,
9+
[MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
810
the number of updates to values, which depending on the algorithm may mean a
911
different number of utility evaluations or imply other computations like solving
1012
a linear or quadratic program.
1113
12-
# Creating stopping criteria
14+
Stopping criteria are callables that are evaluated on a
15+
[ValuationResult][pydvl.value.result.ValuationResult] and return a
16+
[Status][pydvl.utils.status.Status] object. They can be combined using boolean
17+
operators.
18+
19+
## How convergence is determined
20+
21+
Most stopping criteria keep track of the convergence of each index separately
22+
but make global decisions based on the overall convergence of some fraction of
23+
all indices. For example, if we have a stopping criterion that checks whether
24+
the standard error of 90% of values is below a threshold, then methods will keep
25+
updating **all** indices until 90% of them have converged, irrespective of the
26+
quality of the individual estimates, and *without freezing updates* for indices
27+
along the way as values individually attain low standard error.
28+
29+
This has some practical implications, because some values do tend to converge
30+
sooner than others. For example, assume we use the criterion
31+
`AbsoluteStandardError(0.02) | MaxUpdates(1000)`. Then values close to 0 might
32+
be marked as "converged" rather quickly because they fulfill the first
33+
criterion, say after 20 iterations, despite being poor estimates. Because other
34+
indices take much longer to have low standard error and the criterion is a
35+
global check, the "converged" ones keep being updated and end up being good
36+
estimates. In this case, this has been beneficial, but one might not wish for
37+
converged values to be updated, if one is sure that the criterion is adequate
38+
for individual values.
39+
40+
[Semi-value methods][pydvl.value.semivalues] include a parameter
41+
`skip_converged` that allows skipping the computation of values that have
42+
converged. The way to avoid doing this too early is to use a more stringent
43+
check, e.g. `AbsoluteStandardError(1e-3) | MaxUpdates(1000)`. With
44+
`skip_converged=True` this check can still take less time than the first one,
45+
despite requiring more iterations for some indices.
46+
47+
48+
## Choosing a stopping criterion
49+
50+
The choice of a stopping criterion greatly depends on the algorithm and the
51+
context. A safe bet is to combine a [MaxUpdates][pydvl.value.stopping.MaxUpdates]
52+
or a [MaxTime][pydvl.value.stopping.MaxTime] with a
53+
[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] or an
54+
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError]. The former
55+
will ensure that the computation does not run for too long, while the latter
56+
will try to achieve results that are stable enough. Note however that if the
57+
threshold is too strict, one will always end up running until a maximum number
58+
of iterations or time is reached. Also keep in mind that different values converge at
59+
different times, so you might want to use tight thresholds and `skip_converged`
60+
as described above for semi-values.
61+
62+
63+
??? Example
64+
```python
65+
from pydvl.value import AbsoluteStandardError, MaxUpdates, compute_banzhaf_semivalues
66+
67+
utility = ... # some utility object
68+
criterion = AbsoluteStandardError(threshold=1e-3, burn_in=32) | MaxUpdates(1000)
69+
values = compute_banzhaf_semivalues(
70+
utility,
71+
criterion,
72+
skip_converged=True, # skip values that have converged (CAREFUL!)
73+
)
74+
```
75+
This will compute the Banzhaf semivalues for `utility` until either the
76+
absolute standard error is below `1e-3` or `1000` updates have been
77+
performed. The `burn_in` parameter sets the number of initial iterations (`32`)
78+
to ignore before checking for convergence. The `skip_converged` parameter
79+
is used to avoid computing more marginals for indices that have converged,
80+
which is useful if
81+
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError] is met
82+
before [MaxUpdates][pydvl.value.stopping.MaxUpdates] for some indices.
83+
84+
!!! Warning
85+
Be careful not to reuse the same stopping criterion for different
86+
computations. The object has state and will not be reset between calls to
87+
value computation methods. If you need to reuse the same criterion, you
88+
should create a new instance.
89+
90+
91+
## Creating stopping criteria
1392
1493
The easiest way is to declare a function implementing the interface
1594
[StoppingCriterionCallable][pydvl.value.stopping.StoppingCriterionCallable] and
@@ -18,19 +97,19 @@
1897
that can be composed with other stopping criteria.
1998
2099
Alternatively, and in particular if reporting of completion is required, one can
21-
inherit from this class and implement the abstract methods
22-
[_check][pydvl.value.stopping.StoppingCriterion._check] and
100+
inherit from this class and implement the abstract methods `_check` and
23101
[completion][pydvl.value.stopping.StoppingCriterion.completion].
24102
25-
# Composing stopping criteria
103+
## Combining stopping criteria
26104
27105
Objects of type [StoppingCriterion][pydvl.value.stopping.StoppingCriterion] can
28-
be composed with the binary operators `&` (*and*), and `|` (*or*), following the
106+
be combined with the binary operators `&` (*and*) and `|` (*or*), following the
29107
truth tables of [Status][pydvl.utils.status.Status]. The unary operator `~`
30108
(*not*) is also supported. See
31109
[StoppingCriterion][pydvl.value.stopping.StoppingCriterion] for details on how
32110
these operations affect the behavior of the stopping criteria.
33111
112+
34113
## References
35114
36115
[^1]: <a name="ghorbani_data_2019"></a>Ghorbani, A., Zou, J., 2019.
@@ -163,6 +242,15 @@ def converged(self) -> NDArray[np.bool_]:
163242

164243
@property
165244
def name(self):
245+
log = logging.getLogger(__name__)
246+
# This string for the benefit of deprecation searches:
247+
# remove_in="0.8.0"
248+
log.warning(
249+
"The `name` attribute of `StoppingCriterion` is deprecated and will be removed in 0.8.0. "
250+
)
251+
return getattr(self, "_name", type(self).__name__)
252+
253+
def __str__(self):
166254
return type(self).__name__
167255

168256
def __call__(self, result: ValuationResult) -> Status:
@@ -182,23 +270,23 @@ def __and__(self, other: "StoppingCriterion") -> "StoppingCriterion":
182270
fun=lambda result: self._check(result) & other._check(result),
183271
converged=lambda: self.converged & other.converged,
184272
completion=lambda: min(self.completion(), other.completion()),
185-
name=f"Composite StoppingCriterion: {self.name} AND {other.name}",
273+
name=f"Composite StoppingCriterion: {str(self)} AND {str(other)}",
186274
)(modify_result=self.modify_result or other.modify_result)
187275

188276
def __or__(self, other: "StoppingCriterion") -> "StoppingCriterion":
189277
return make_criterion(
190278
fun=lambda result: self._check(result) | other._check(result),
191279
converged=lambda: self.converged | other.converged,
192280
completion=lambda: max(self.completion(), other.completion()),
193-
name=f"Composite StoppingCriterion: {self.name} OR {other.name}",
281+
name=f"Composite StoppingCriterion: {str(self)} OR {str(other)}",
194282
)(modify_result=self.modify_result or other.modify_result)
195283

196284
def __invert__(self) -> "StoppingCriterion":
197285
return make_criterion(
198286
fun=lambda result: ~self._check(result),
199287
converged=lambda: ~self.converged,
200288
completion=lambda: 1 - self.completion(),
201-
name=f"Composite StoppingCriterion: NOT {self.name}",
289+
name=f"Composite StoppingCriterion: NOT {str(self)}",
202290
)(modify_result=self.modify_result)
203291

204292

@@ -239,8 +327,7 @@ def converged(self) -> NDArray[np.bool_]:
239327
return super().converged
240328
return converged()
241329

242-
@property
243-
def name(self):
330+
def __str__(self):
244331
return self._name
245332

246333
def completion(self) -> float:
@@ -254,13 +341,13 @@ def completion(self) -> float:
254341
class AbsoluteStandardError(StoppingCriterion):
255342
r"""Determine convergence based on the standard error of the values.
256343
257-
If $s_i$ is the standard error for datum $i$ and $v_i$ its value, then this
258-
criterion returns [Converged][pydvl.utils.status.Status] if
259-
$s_i < \epsilon$ for all $i$ and a threshold value $\epsilon \gt 0$.
344+
If $s_i$ is the standard error for datum $i$, then this criterion returns
345+
[Converged][pydvl.utils.status.Status] if $s_i < \epsilon$ for all $i$ and a
346+
threshold value $\epsilon \gt 0$.
260347
261348
Args:
262349
threshold: A value is considered to have converged if the standard
263-
error is below this value. A way of choosing it is to pick some
350+
error is below this threshold. A way of choosing it is to pick some
264351
percentage of the range of the values. For Shapley values this is
265352
the difference between the maximum and minimum of the utility
266353
function (to see this substitute the maximum and minimum values of
@@ -270,7 +357,7 @@ class AbsoluteStandardError(StoppingCriterion):
270357
burn_in: The number of iterations to ignore before checking for
271358
convergence. This is required because computations typically start
272359
with zero variance, as a result of using
273-
[empty()][pydvl.value.result.ValuationResult.empty]. The default is
360+
[zeros()][pydvl.value.result.ValuationResult.zeros]. The default is
274361
set to an arbitrary minimum which is usually enough but may need to
275362
be increased.
276363
"""
@@ -295,6 +382,9 @@ def _check(self, result: ValuationResult) -> Status:
295382
return Status.Converged
296383
return Status.Pending
297384

385+
def __str__(self):
386+
return f"AbsoluteStandardError(threshold={self.threshold}, fraction={self.fraction}, burn_in={self.burn_in})"
387+
298388

299389
class StandardError(AbsoluteStandardError):
300390
@deprecated(target=AbsoluteStandardError, deprecated_in="0.6.0", remove_in="0.8.0")
@@ -333,6 +423,9 @@ def completion(self) -> float:
333423
return min(1.0, self._count / self.n_checks)
334424
return 0.0
335425

426+
def __str__(self):
427+
return f"MaxChecks(n_checks={self.n_checks})"
428+
336429

337430
class MaxUpdates(StoppingCriterion):
338431
"""Terminate if any number of value updates exceeds or equals the given
@@ -377,6 +470,9 @@ def completion(self) -> float:
377470
return self.last_max / self.n_updates
378471
return 0.0
379472

473+
def __str__(self):
474+
return f"MaxUpdates(n_updates={self.n_updates})"
475+
380476

381477
class MinUpdates(StoppingCriterion):
382478
"""Terminate as soon as all value updates exceed or equal the given threshold.
@@ -414,6 +510,9 @@ def completion(self) -> float:
414510
return self.last_min / self.n_updates
415511
return 0.0
416512

513+
def __str__(self):
514+
return f"MinUpdates(n_updates={self.n_updates})"
515+
417516

418517
class MaxTime(StoppingCriterion):
419518
"""Terminate if the computation time exceeds the given number of seconds.
@@ -447,6 +546,9 @@ def completion(self) -> float:
447546
return 0.0
448547
return (time() - self.start) / self.max_seconds
449548

549+
def __str__(self):
550+
return f"MaxTime(seconds={self.max_seconds})"
551+
450552

451553
class HistoryDeviation(StoppingCriterion):
452554
r"""A simple check for relative distance to a previous step in the
@@ -527,3 +629,6 @@ def _check(self, r: ValuationResult) -> Status:
527629
if np.all(self._converged):
528630
return Status.Converged
529631
return Status.Pending
632+
633+
def __str__(self):
634+
return f"HistoryDeviation(n_steps={self.n_steps}, rtol={self.rtol})"

0 commit comments

Comments
 (0)