-"""
+r"""
 Stopping criteria for value computations.
 
-This module provides a basic set of stopping criteria, like [MaxUpdates][pydvl.value.stopping.MaxUpdates],
-[MaxTime][pydvl.value.stopping.MaxTime], or [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others.
-These can behave in different ways depending on the context.
-For example, [MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
+This module provides a basic set of stopping criteria, like
+[MaxUpdates][pydvl.value.stopping.MaxUpdates],
+[MaxTime][pydvl.value.stopping.MaxTime], or
+[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others. These
+can behave in different ways depending on the context. For example,
+[MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
 the number of updates to values, which depending on the algorithm may mean a
 different number of utility evaluations or imply other computations like solving
 a linear or quadratic program.
 
-# Creating stopping criteria
+Stopping criteria are callables that are evaluated on a
+[ValuationResult][pydvl.value.result.ValuationResult] and return a
+[Status][pydvl.utils.status.Status] object. They can be combined using boolean
+operators.
+
+## How convergence is determined
+
+Most stopping criteria keep track of the convergence of each index separately
+but make global decisions based on the overall convergence of some fraction of
+all indices. For example, if we have a stopping criterion that checks whether
+the standard error of 90% of values is below a threshold, then methods will keep
+updating **all** indices until 90% of them have converged, irrespective of the
+quality of the individual estimates, and *without freezing updates* for indices
+along the way as values individually attain low standard error.
+
+This has some practical implications, because some values do tend to converge
+sooner than others. For example, assume we use the criterion
+`AbsoluteStandardError(0.02) | MaxUpdates(1000)`. Then values close to 0 might
+be marked as "converged" rather quickly because they fulfill the first
+criterion, say after 20 iterations, despite being poor estimates. Because other
+indices take much longer to have low standard error and the criterion is a
+global check, the "converged" ones keep being updated and end up being good
+estimates. In this case, this has been beneficial, but one might not wish for
+converged values to be updated, if one is sure that the criterion is adequate
+for individual values.
+
+[Semi-value methods][pydvl.value.semivalues] include a parameter
+`skip_converged` that allows skipping the computation of values that have
+converged. The way to avoid doing this too early is to use a more stringent
+check, e.g. `AbsoluteStandardError(1e-3) | MaxUpdates(1000)`. With
+`skip_converged=True` this check can still take less time than the first one,
+despite requiring more iterations for some indices.
+
+
+## Choosing a stopping criterion
+
+The choice of a stopping criterion greatly depends on the algorithm and the
+context. A safe bet is to combine a [MaxUpdates][pydvl.value.stopping.MaxUpdates]
+or a [MaxTime][pydvl.value.stopping.MaxTime] with a
+[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] or an
+[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError]. The former
+will ensure that the computation does not run for too long, while the latter
+will try to achieve results that are stable enough. Note however that if the
+threshold is too strict, one will always end up running until a maximum number
+of iterations or time. Also keep in mind that different values converge at
+different times, so you might want to use tight thresholds and `skip_converged`
+as described above for semi-values.
+
+
+??? Example
+    ```python
+    from pydvl.value import AbsoluteStandardError, MaxUpdates, compute_banzhaf_semivalues
+
+    utility = ...  # some utility object
+    criterion = AbsoluteStandardError(threshold=1e-3, burn_in=32) | MaxUpdates(1000)
+    values = compute_banzhaf_semivalues(
+        utility,
+        criterion,
+        skip_converged=True,  # skip values that have converged (CAREFUL!)
+    )
+    ```
+    This will compute the Banzhaf semivalues for `utility` until either the
+    absolute standard error is below `1e-3` or `1000` updates have been
+    performed. The `burn_in` parameter is used to discard the first `32` updates
+    from the computation of the standard error. The `skip_converged` parameter
+    is used to avoid computing more marginals for indices that have converged,
+    which is useful if
+    [AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError] is met
+    before [MaxUpdates][pydvl.value.stopping.MaxUpdates] for some indices.
+
+!!! Warning
+    Be careful not to reuse the same stopping criterion for different
+    computations. The object has state and will not be reset between calls to
+    value computation methods. If you need to reuse the same criterion, you
+    should create a new instance.
+
+
+## Creating stopping criteria
 
 The easiest way is to declare a function implementing the interface
 [StoppingCriterionCallable][pydvl.value.stopping.StoppingCriterionCallable] and
 that can be composed with other stopping criteria.
 
 Alternatively, and in particular if reporting of completion is required, one can
-inherit from this class and implement the abstract methods
-[_check][pydvl.value.stopping.StoppingCriterion._check] and
+inherit from this class and implement the abstract methods `_check` and
 [completion][pydvl.value.stopping.StoppingCriterion.completion].
 
-# Composing stopping criteria
+## Combining stopping criteria
 
 Objects of type [StoppingCriterion][pydvl.value.stopping.StoppingCriterion] can
-be composed with the binary operators `&` (*and*), and `|` (*or*), following the
+be combined with the binary operators `&` (*and*), and `|` (*or*), following the
 truth tables of [Status][pydvl.utils.status.Status]. The unary operator `~`
 (*not*) is also supported. See
 [StoppingCriterion][pydvl.value.stopping.StoppingCriterion] for details on how
 these operations affect the behavior of the stopping criteria.
 
+
 ## References
 
 [^1]: <a name="ghorbani_data_2019"></a>Ghorbani, A., Zou, J., 2019.
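The docstring's description of how convergence is determined (a per-index check combined with a global decision over a fraction of indices) can be illustrated with a small sketch. It does not use pyDVL: `converged_mask` and `globally_converged` are hypothetical helpers, and the standard errors are made-up numbers chosen only to show the behaviour.

```python
from typing import List, Sequence


def converged_mask(stderrs: Sequence[float], threshold: float) -> List[bool]:
    """Per-index check: True where the standard error is below the threshold."""
    return [s < threshold for s in stderrs]


def globally_converged(
    stderrs: Sequence[float], threshold: float, fraction: float = 1.0
) -> bool:
    """Global check: True once at least `fraction` of the indices pass."""
    mask = converged_mask(stderrs, threshold)
    return sum(mask) / len(mask) >= fraction


# Hypothetical standard errors for five indices after some iterations.
stderrs = [0.001, 0.015, 0.019, 0.05, 0.018]

mask = converged_mask(stderrs, threshold=0.02)

# Four of five indices pass individually. With fraction=0.9 the global check
# still fails, so all five indices would keep being updated.
stop_at_90 = globally_converged(stderrs, threshold=0.02, fraction=0.9)

# With fraction=0.8 the same data is enough to stop.
stop_at_80 = globally_converged(stderrs, threshold=0.02, fraction=0.8)

print(mask, stop_at_90, stop_at_80)
```

In this toy model, a `skip_converged`-style option would correspond to no longer updating the indices whose entry in `mask` is already `True`, which is why the docstring recommends a tighter threshold when that option is enabled.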
@@ -166,6 +245,15 @@ def converged(self) -> NDArray[np.bool_]:
 
     @property
     def name(self):
+        log = logging.getLogger(__name__)
+        # This string for the benefit of deprecation searches:
+        # remove_in="0.8.0"
+        log.warning(
+            "The `name` attribute of `StoppingCriterion` is deprecated and will be removed in 0.8.0. "
+        )
+        return getattr(self, "_name", type(self).__name__)
+
+    def __str__(self):
         return type(self).__name__
 
     def __call__(self, result: ValuationResult) -> Status:
@@ -185,23 +273,23 @@ def __and__(self, other: "StoppingCriterion") -> "StoppingCriterion":
             fun=lambda result: self._check(result) & other._check(result),
             converged=lambda: self.converged & other.converged,
             completion=lambda: min(self.completion(), other.completion()),
-            name=f"Composite StoppingCriterion: {self.name} AND {other.name}",
+            name=f"Composite StoppingCriterion: {str(self)} AND {str(other)}",
         )(modify_result=self.modify_result or other.modify_result)
 
     def __or__(self, other: "StoppingCriterion") -> "StoppingCriterion":
         return make_criterion(
             fun=lambda result: self._check(result) | other._check(result),
             converged=lambda: self.converged | other.converged,
             completion=lambda: max(self.completion(), other.completion()),
-            name=f"Composite StoppingCriterion: {self.name} OR {other.name}",
+            name=f"Composite StoppingCriterion: {str(self)} OR {str(other)}",
         )(modify_result=self.modify_result or other.modify_result)
 
     def __invert__(self) -> "StoppingCriterion":
         return make_criterion(
             fun=lambda result: ~self._check(result),
             converged=lambda: ~self.converged,
             completion=lambda: 1 - self.completion(),
-            name=f"Composite StoppingCriterion: NOT {self.name}",
+            name=f"Composite StoppingCriterion: NOT {str(self)}",
         )(modify_result=self.modify_result)
 
 
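The composition pattern in this hunk (`&` takes the minimum completion, `|` the maximum, `~` the complement, and the composite's label is built from `str()` of its operands) can be tried in isolation with the following sketch. It is a simplified stand-in, not the pyDVL implementation: the `Criterion` class and `max_updates` factory are hypothetical, plain booleans replace `Status`, and the criteria are stateless functions of an update count.

```python
from typing import Callable


class Criterion:
    """Hypothetical, simplified criterion using booleans instead of Status."""

    def __init__(
        self,
        check: Callable[[int], bool],
        completion: Callable[[int], float],
        label: str,
    ):
        self._check = check
        self._completion = completion
        self._label = label

    def __call__(self, n_updates: int) -> bool:
        return self._check(n_updates)

    def completion(self, n_updates: int) -> float:
        return self._completion(n_updates)

    def __str__(self) -> str:
        return self._label

    def __and__(self, other: "Criterion") -> "Criterion":
        # Both must hold, so progress is limited by the slower operand.
        return Criterion(
            lambda n: self(n) and other(n),
            lambda n: min(self.completion(n), other.completion(n)),
            f"Composite: {self} AND {other}",
        )

    def __or__(self, other: "Criterion") -> "Criterion":
        # Either suffices, so progress follows the faster operand.
        return Criterion(
            lambda n: self(n) or other(n),
            lambda n: max(self.completion(n), other.completion(n)),
            f"Composite: {self} OR {other}",
        )

    def __invert__(self) -> "Criterion":
        return Criterion(
            lambda n: not self(n),
            lambda n: 1 - self.completion(n),
            f"Composite: NOT {self}",
        )


def max_updates(limit: int) -> Criterion:
    """Hypothetical factory: stop once `limit` updates have been made."""
    return Criterion(
        lambda n: n >= limit,
        lambda n: min(1.0, n / limit),
        f"MaxUpdates(n_updates={limit})",
    )


a, b = max_updates(100), max_updates(1000)
either, both, not_a = a | b, a & b, ~a

print(either)                                     # label built from str() of operands
print(either(150), both(150), not_a(150))         # stop decisions at 150 updates
print(either.completion(50), both.completion(50)) # max vs. min completion
```

Building the label from `str(self)` rather than a `name` property means each operand can report its own parameters in the composite label, which is what the per-class `__str__` methods added elsewhere in this diff provide.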
@@ -242,8 +330,7 @@ def converged(self) -> NDArray[np.bool_]:
                 return super().converged
             return converged()
 
-        @property
-        def name(self):
+        def __str__(self):
             return self._name
 
         def completion(self) -> float:
@@ -257,13 +344,13 @@ def completion(self) -> float:
 class AbsoluteStandardError(StoppingCriterion):
     r"""Determine convergence based on the standard error of the values.
 
-    If $s_i$ is the standard error for datum $i$ and $v_i$ its value, then this
-    criterion returns [Converged][pydvl.utils.status.Status] if
-    $s_i < \epsilon$ for all $i$ and a threshold value $\epsilon \gt 0$.
+    If $s_i$ is the standard error for datum $i$, then this criterion returns
+    [Converged][pydvl.utils.status.Status] if $s_i < \epsilon$ for all $i$ and a
+    threshold value $\epsilon \gt 0$.
 
     Args:
         threshold: A value is considered to have converged if the standard
-            error is below this value. A way of choosing it is to pick some
+            error is below this threshold. A way of choosing it is to pick some
             percentage of the range of the values. For Shapley values this is
             the difference between the maximum and minimum of the utility
             function (to see this substitute the maximum and minimum values of
@@ -273,7 +360,7 @@ class AbsoluteStandardError(StoppingCriterion):
         burn_in: The number of iterations to ignore before checking for
             convergence. This is required because computations typically start
             with zero variance, as a result of using
-            [empty()][pydvl.value.result.ValuationResult.empty]. The default is
+            [zeros()][pydvl.value.result.ValuationResult.zeros]. The default is
             set to an arbitrary minimum which is usually enough but may need to
             be increased.
     """
@@ -298,6 +385,9 @@ def _check(self, result: ValuationResult) -> Status:
             return Status.Converged
         return Status.Pending
 
+    def __str__(self):
+        return f"AbsoluteStandardError(threshold={self.threshold}, fraction={self.fraction}, burn_in={self.burn_in})"
+
 
 class StandardError(AbsoluteStandardError):
     @deprecated(target=AbsoluteStandardError, deprecated_in="0.6.0", remove_in="0.8.0")
@@ -339,6 +429,9 @@ def completion(self) -> float:
     def reset(self):
         self._count = 0
 
+    def __str__(self):
+        return f"MaxChecks(n_checks={self.n_checks})"
+
 
 class MaxUpdates(StoppingCriterion):
     """Terminate if any number of value updates exceeds or equals the given
@@ -383,6 +476,9 @@ def completion(self) -> float:
             return self.last_max / self.n_updates
         return 0.0
 
+    def __str__(self):
+        return f"MaxUpdates(n_updates={self.n_updates})"
+
 
 class MinUpdates(StoppingCriterion):
     """Terminate as soon as all value updates exceed or equal the given threshold.
@@ -420,6 +516,9 @@ def completion(self) -> float:
             return self.last_min / self.n_updates
         return 0.0
 
+    def __str__(self):
+        return f"MinUpdates(n_updates={self.n_updates})"
+
 
 class MaxTime(StoppingCriterion):
     """Terminate if the computation time exceeds the given number of seconds.
@@ -455,6 +554,10 @@ def completion(self) -> float:
 
     def reset(self):
         self.start = time()
+
+    def __str__(self):
+        return f"MaxTime(seconds={self.max_seconds})"
+
 
 
 class HistoryDeviation(StoppingCriterion):
@@ -539,3 +642,6 @@ def _check(self, r: ValuationResult) -> Status:
 
     def reset(self):
         self._memory = None  # type: ignore
+
+    def __str__(self):
+        return f"HistoryDeviation(n_steps={self.n_steps}, rtol={self.rtol})"
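The docstring's warning about reusing a criterion relates to the internal state that the `reset` methods in this diff clear (`_count`, `start`, `_memory`). A minimal sketch of why reuse is a problem, under stated assumptions: `CheckCounter` and `run` are hypothetical and only loosely modelled on a check-counting criterion; they are not pyDVL code, and whether calling `reset()` is sufficient in the real library is not established by this diff.

```python
class CheckCounter:
    """Hypothetical criterion that stops after a fixed number of checks."""

    def __init__(self, n_checks: int):
        self.n_checks = n_checks
        self._count = 0  # state that survives between computations

    def __call__(self) -> bool:
        self._count += 1
        return self._count > self.n_checks

    def reset(self) -> None:
        self._count = 0


def run(criterion: CheckCounter) -> int:
    """Toy computation: iterate until the criterion says stop."""
    iterations = 0
    while not criterion():
        iterations += 1
    return iterations


shared = CheckCounter(5)
first = run(shared)    # runs the full 5 iterations
second = run(shared)   # stale counter: stops immediately

fresh = run(CheckCounter(5))  # a new instance behaves as expected

shared.reset()
after_reset = run(shared)     # in this toy model, reset restores the behaviour

print(first, second, fresh, after_reset)
```

Creating a new instance per computation, as the warning recommends, avoids relying on every piece of state being covered by `reset`.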