Skip to content

Commit a27a655

Browse files
authored
Merge branch 'develop' into 259-implement-class-wise-shapley
2 parents 94d95db + 4c10cc3 commit a27a655

File tree

7 files changed

+176
-36
lines changed

7 files changed

+176
-36
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
[PR #338](https://github.com/aai-institute/pyDVL/pull/338)
77
- No longer using docker within tests to start a memcached server
88
[PR #444](https://github.com/aai-institute/pyDVL/pull/444)
9+
- Faster semi-value computation with per-index check of stopping criteria (optional)
10+
[PR #437](https://github.com/aai-institute/pyDVL/pull/437)
911
- Improvements and fixes to notebooks
1012
[PR #436](https://github.com/aai-institute/pyDVL/pull/436)
1113
- Fix initialization of `data_names` in `ValuationResult.zeros()`

src/pydvl/reporting/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def plot_ci_array(
104104
means = np.mean(data, axis=0)
105105
variances = np.var(data, axis=0, ddof=1)
106106

107-
dummy: ValuationResult[np.int_, np.object_] = ValuationResult(
107+
dummy = ValuationResult[np.int_, np.object_](
108108
algorithm="dummy",
109109
values=means,
110110
variances=variances,

src/pydvl/value/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __len__(self) -> int:
200200
return len(self._outer_indices)
201201

202202
def __str__(self):
203-
return f"{self.__class__.__name__}"
203+
return self.__class__.__name__
204204

205205
def __repr__(self):
206206
return f"{self.__class__.__name__}({self._indices}, {self._outer_indices})"

src/pydvl/value/semivalues.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
import warnings
9393
from enum import Enum
9494
from itertools import islice
95-
from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
95+
from typing import Iterable, List, Optional, Protocol, Tuple, Type, cast
9696

9797
import scipy as sp
9898
from deprecate import deprecated
@@ -143,7 +143,7 @@ def __call__(self, n: int, k: int) -> float:
143143

144144

145145
def _marginal(
146-
u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT]
146+
u: Utility, coefficient: SVCoefficient, samples: Iterable[SampleT]
147147
) -> Tuple[MarginalT, ...]:
148148
"""Computation of marginal utility. This is a helper function for
149149
[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues].
@@ -178,6 +178,7 @@ def compute_generic_semivalues(
178178
done: StoppingCriterion,
179179
*,
180180
batch_size: int = 1,
181+
skip_converged: bool = False,
181182
n_jobs: int = 1,
182183
config: ParallelConfig = ParallelConfig(),
183184
progress: bool = False,
@@ -190,6 +191,15 @@ def compute_generic_semivalues(
190191
coefficient: The semi-value coefficient
191192
done: Stopping criterion.
192193
batch_size: Number of marginal evaluations per single parallel job.
194+
skip_converged: Whether to skip marginal evaluations for indices that
195+
have already converged. **CAUTION**: This is only entirely safe if
196+
the stopping criterion is [MaxUpdates][pydvl.value.stopping.MaxUpdates].
197+
For any other stopping criterion, the convergence status of indices
198+
may change during the computation, or they may be marked as having
199+
converged even though in fact the estimated values are far from the
200+
true values (e.g. for
201+
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError],
202+
you will probably have to carefully adjust the threshold).
193203
n_jobs: Number of parallel jobs to use.
194204
config: Object configuring parallel computation, with cluster
195205
address, number of cpus, etc.
@@ -254,16 +264,33 @@ def compute_generic_semivalues(
254264

255265
# Ensure that we always have n_submitted_jobs running
256266
try:
257-
for _ in range(n_submitted_jobs - len(pending)):
267+
while len(pending) < n_submitted_jobs:
258268
samples = tuple(islice(sampler_it, batch_size))
259269
if len(samples) == 0:
260270
raise StopIteration
261271

262-
pending.add(
263-
executor.submit(
264-
_marginal, u=u, coefficient=correction, samples=samples
272+
# Filter out samples for indices that have already converged
273+
filtered_samples = samples
274+
if skip_converged and len(done.converged) > 0:
275+
# cloudpickle can't pickle this on python 3.8:
276+
# filtered_samples = filter(
277+
# lambda t: not done.converged[t[0]], samples
278+
# )
279+
filtered_samples = tuple(
280+
(idx, sample)
281+
for idx, sample in samples
282+
if not done.converged[idx]
283+
)
284+
285+
if filtered_samples:
286+
pending.add(
287+
executor.submit(
288+
_marginal,
289+
u=u,
290+
coefficient=correction,
291+
samples=filtered_samples,
292+
)
265293
)
266-
)
267294
except StopIteration:
268295
if len(pending) == 0:
269296
return result

src/pydvl/value/stopping.py

Lines changed: 126 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,94 @@
1-
"""
1+
r"""
22
Stopping criteria for value computations.
33
4-
This module provides a basic set of stopping criteria, like [MaxUpdates][pydvl.value.stopping.MaxUpdates],
5-
[MaxTime][pydvl.value.stopping.MaxTime], or [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others.
6-
These can behave in different ways depending on the context.
7-
For example, [MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
4+
This module provides a basic set of stopping criteria, like
5+
[MaxUpdates][pydvl.value.stopping.MaxUpdates],
6+
[MaxTime][pydvl.value.stopping.MaxTime], or
7+
[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others. These
8+
can behave in different ways depending on the context. For example,
9+
[MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
810
the number of updates to values, which depending on the algorithm may mean a
911
different number of utility evaluations or imply other computations like solving
1012
a linear or quadratic program.
1113
12-
# Creating stopping criteria
14+
Stopping criteria are callables that are evaluated on a
15+
[ValuationResult][pydvl.value.result.ValuationResult] and return a
16+
[Status][pydvl.utils.status.Status] object. They can be combined using boolean
17+
operators.
18+
19+
## How convergence is determined
20+
21+
Most stopping criteria keep track of the convergence of each index separately
22+
but make global decisions based on the overall convergence of some fraction of
23+
all indices. For example, if we have a stopping criterion that checks whether
24+
the standard error of 90% of values is below a threshold, then methods will keep
25+
updating **all** indices until 90% of them have converged, irrespective of the
26+
quality of the individual estimates, and *without freezing updates* for indices
27+
along the way as values individually attain low standard error.
28+
29+
This has some practical implications, because some values do tend to converge
30+
sooner than others. For example, assume we use the criterion
31+
`AbsoluteStandardError(0.02) | MaxUpdates(1000)`. Then values close to 0 might
32+
be marked as "converged" rather quickly because they fulfill the first
33+
criterion, say after 20 iterations, despite being poor estimates. Because other
34+
indices take much longer to have low standard error and the criterion is a
35+
global check, the "converged" ones keep being updated and end up being good
36+
estimates. In this case, this has been beneficial, but one might not wish for
37+
converged values to be updated, if one is sure that the criterion is adequate
38+
for individual values.
39+
40+
[Semi-value methods][pydvl.value.semivalues] include a parameter
41+
`skip_converged` that allows to skip the computation of values that have
42+
converged. The way to avoid doing this too early is to use a more stringent
43+
check, e.g. `AbsoluteStandardError(1e-3) | MaxUpdates(1000)`. With
44+
`skip_converged=True` this check can still take less time than the first one,
45+
despite requiring more iterations for some indices.
46+
47+
48+
## Choosing a stopping criterion
49+
50+
The choice of a stopping criterion greatly depends on the algorithm and the
51+
context. A safe bet is to combine a [MaxUpdates][pydvl.value.stopping.MaxUpdates]
52+
or a [MaxTime][pydvl.value.stopping.MaxTime] with a
53+
[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] or an
54+
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError]. The former
55+
will ensure that the computation does not run for too long, while the latter
56+
will try to achieve results that are stable enough. Note however that if the
57+
threshold is too strict, one will always end up running until a maximum number
58+
of iterations or time. Also keep in mind that different values converge at
59+
different times, so you might want to use tight thresholds and `skip_converged`
60+
as described above for semi-values.
61+
62+
63+
??? Example
64+
```python
65+
from pydvl.value import AbsoluteStandardError, MaxUpdates, compute_banzhaf_semivalues
66+
67+
utility = ... # some utility object
68+
criterion = AbsoluteStandardError(threshold=1e-3, burn_in=32) | MaxUpdates(1000)
69+
values = compute_banzhaf_semivalues(
70+
utility,
71+
criterion,
72+
skip_converged=True, # skip values that have converged (CAREFUL!)
73+
)
74+
```
75+
This will compute the Banzhaf semivalues for `utility` until either the
76+
absolute standard error is below `1e-3` or `1000` updates have been
77+
performed. The `burn_in` parameter is used to discard the first `32` updates
78+
from the computation of the standard error. The `skip_converged` parameter
79+
is used to avoid computing more marginals for indices that have converged,
80+
which is useful if
81+
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError] is met
82+
before [MaxUpdates][pydvl.value.stopping.MaxUpdates] for some indices.
83+
84+
!!! Warning
85+
Be careful not to reuse the same stopping criterion for different
86+
computations. The object has state and will not be reset between calls to
87+
value computation methods. If you need to reuse the same criterion, you
88+
should create a new instance.
89+
90+
91+
## Creating stopping criteria
1392
1493
The easiest way is to declare a function implementing the interface
1594
[StoppingCriterionCallable][pydvl.value.stopping.StoppingCriterionCallable] and
@@ -18,19 +97,19 @@
1897
that can be composed with other stopping criteria.
1998
2099
Alternatively, and in particular if reporting of completion is required, one can
21-
inherit from this class and implement the abstract methods
22-
[_check][pydvl.value.stopping.StoppingCriterion._check] and
100+
inherit from this class and implement the abstract methods `_check` and
23101
[completion][pydvl.value.stopping.StoppingCriterion.completion].
24102
25-
# Composing stopping criteria
103+
## Combining stopping criteria
26104
27105
Objects of type [StoppingCriterion][pydvl.value.stopping.StoppingCriterion] can
28-
be composed with the binary operators `&` (*and*), and `|` (*or*), following the
106+
be combined with the binary operators `&` (*and*), and `|` (*or*), following the
29107
truth tables of [Status][pydvl.utils.status.Status]. The unary operator `~`
30108
(*not*) is also supported. See
31109
[StoppingCriterion][pydvl.value.stopping.StoppingCriterion] for details on how
32110
these operations affect the behavior of the stopping criteria.
33111
112+
34113
## References
35114
36115
[^1]: <a name="ghorbani_data_2019"></a>Ghorbani, A., Zou, J., 2019.
@@ -166,6 +245,15 @@ def converged(self) -> NDArray[np.bool_]:
166245

167246
@property
168247
def name(self):
248+
log = logging.getLogger(__name__)
249+
# This string for the benefit of deprecation searches:
250+
# remove_in="0.8.0"
251+
log.warning(
252+
"The `name` attribute of `StoppingCriterion` is deprecated and will be removed in 0.8.0. "
253+
)
254+
return getattr(self, "_name", type(self).__name__)
255+
256+
def __str__(self):
169257
return type(self).__name__
170258

171259
def __call__(self, result: ValuationResult) -> Status:
@@ -185,23 +273,23 @@ def __and__(self, other: "StoppingCriterion") -> "StoppingCriterion":
185273
fun=lambda result: self._check(result) & other._check(result),
186274
converged=lambda: self.converged & other.converged,
187275
completion=lambda: min(self.completion(), other.completion()),
188-
name=f"Composite StoppingCriterion: {self.name} AND {other.name}",
276+
name=f"Composite StoppingCriterion: {str(self)} AND {str(other)}",
189277
)(modify_result=self.modify_result or other.modify_result)
190278

191279
def __or__(self, other: "StoppingCriterion") -> "StoppingCriterion":
192280
return make_criterion(
193281
fun=lambda result: self._check(result) | other._check(result),
194282
converged=lambda: self.converged | other.converged,
195283
completion=lambda: max(self.completion(), other.completion()),
196-
name=f"Composite StoppingCriterion: {self.name} OR {other.name}",
284+
name=f"Composite StoppingCriterion: {str(self)} OR {str(other)}",
197285
)(modify_result=self.modify_result or other.modify_result)
198286

199287
def __invert__(self) -> "StoppingCriterion":
200288
return make_criterion(
201289
fun=lambda result: ~self._check(result),
202290
converged=lambda: ~self.converged,
203291
completion=lambda: 1 - self.completion(),
204-
name=f"Composite StoppingCriterion: NOT {self.name}",
292+
name=f"Composite StoppingCriterion: NOT {str(self)}",
205293
)(modify_result=self.modify_result)
206294

207295

@@ -242,8 +330,7 @@ def converged(self) -> NDArray[np.bool_]:
242330
return super().converged
243331
return converged()
244332

245-
@property
246-
def name(self):
333+
def __str__(self):
247334
return self._name
248335

249336
def completion(self) -> float:
@@ -257,13 +344,13 @@ def completion(self) -> float:
257344
class AbsoluteStandardError(StoppingCriterion):
258345
r"""Determine convergence based on the standard error of the values.
259346
260-
If $s_i$ is the standard error for datum $i$ and $v_i$ its value, then this
261-
criterion returns [Converged][pydvl.utils.status.Status] if
262-
$s_i < \epsilon$ for all $i$ and a threshold value $\epsilon \gt 0$.
347+
If $s_i$ is the standard error for datum $i$, then this criterion returns
348+
[Converged][pydvl.utils.status.Status] if $s_i < \epsilon$ for all $i$ and a
349+
threshold value $\epsilon \gt 0$.
263350
264351
Args:
265352
threshold: A value is considered to have converged if the standard
266-
error is below this value. A way of choosing it is to pick some
353+
error is below this threshold. A way of choosing it is to pick some
267354
percentage of the range of the values. For Shapley values this is
268355
the difference between the maximum and minimum of the utility
269356
function (to see this substitute the maximum and minimum values of
@@ -273,7 +360,7 @@ class AbsoluteStandardError(StoppingCriterion):
273360
burn_in: The number of iterations to ignore before checking for
274361
convergence. This is required because computations typically start
275362
with zero variance, as a result of using
276-
[empty()][pydvl.value.result.ValuationResult.empty]. The default is
363+
[zeros()][pydvl.value.result.ValuationResult.zeros]. The default is
277364
set to an arbitrary minimum which is usually enough but may need to
278365
be increased.
279366
"""
@@ -298,6 +385,9 @@ def _check(self, result: ValuationResult) -> Status:
298385
return Status.Converged
299386
return Status.Pending
300387

388+
def __str__(self):
389+
return f"AbsoluteStandardError(threshold={self.threshold}, fraction={self.fraction}, burn_in={self.burn_in})"
390+
301391

302392
class StandardError(AbsoluteStandardError):
303393
@deprecated(target=AbsoluteStandardError, deprecated_in="0.6.0", remove_in="0.8.0")
@@ -339,6 +429,9 @@ def completion(self) -> float:
339429
def reset(self):
340430
self._count = 0
341431

432+
def __str__(self):
433+
return f"MaxChecks(n_checks={self.n_checks})"
434+
342435

343436
class MaxUpdates(StoppingCriterion):
344437
"""Terminate if any number of value updates exceeds or equals the given
@@ -383,6 +476,9 @@ def completion(self) -> float:
383476
return self.last_max / self.n_updates
384477
return 0.0
385478

479+
def __str__(self):
480+
return f"MaxUpdates(n_updates={self.n_updates})"
481+
386482

387483
class MinUpdates(StoppingCriterion):
388484
"""Terminate as soon as all value updates exceed or equal the given threshold.
@@ -420,6 +516,9 @@ def completion(self) -> float:
420516
return self.last_min / self.n_updates
421517
return 0.0
422518

519+
def __str__(self):
520+
return f"MinUpdates(n_updates={self.n_updates})"
521+
423522

424523
class MaxTime(StoppingCriterion):
425524
"""Terminate if the computation time exceeds the given number of seconds.
@@ -455,6 +554,10 @@ def completion(self) -> float:
455554

456555
def reset(self):
457556
self.start = time()
557+
558+
def __str__(self):
559+
return f"MaxTime(seconds={self.max_seconds})"
560+
458561

459562

460563
class HistoryDeviation(StoppingCriterion):
@@ -539,3 +642,6 @@ def _check(self, r: ValuationResult) -> Status:
539642

540643
def reset(self):
541644
self._memory = None # type: ignore
645+
646+
def __str__(self):
647+
return f"HistoryDeviation(n_steps={self.n_steps}, rtol={self.rtol})"

0 commit comments

Comments
 (0)