1- """
1+ r """
22Stopping criteria for value computations.
33
4- This module provides a basic set of stopping criteria, like [MaxUpdates][pydvl.value.stopping.MaxUpdates],
5- [MaxTime][pydvl.value.stopping.MaxTime], or [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others.
6- These can behave in different ways depending on the context.
7- For example, [MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
4+ This module provides a basic set of stopping criteria, like
5+ [MaxUpdates][pydvl.value.stopping.MaxUpdates],
6+ [MaxTime][pydvl.value.stopping.MaxTime], or
7+ [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others. These
8+ can behave in different ways depending on the context. For example,
9+ [MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
810the number of updates to values, which depending on the algorithm may mean a
911different number of utility evaluations or imply other computations like solving
1012a linear or quadratic program.
1113
12- # Creating stopping criteria
14+ Stopping criteria are callables that are evaluated on a
15+ [ValuationResult][pydvl.value.result.ValuationResult] and return a
16+ [Status][pydvl.utils.status.Status] object. They can be combined using boolean
17+ operators.
18+
19+ ## How convergence is determined
20+
21+ Most stopping criteria keep track of the convergence of each index separately
22+ but make global decisions based on the overall convergence of some fraction of
23+ all indices. For example, if we have a stopping criterion that checks whether
24+ the standard error of 90% of values is below a threshold, then methods will keep
25+ updating **all** indices until 90% of them have converged, irrespective of the
26+ quality of the individual estimates, and *without freezing updates* for indices
27+ along the way as values individually attain low standard error.
28+
29+ This has some practical implications, because some values do tend to converge
30+ sooner than others. For example, assume we use the criterion
31+ `AbsoluteStandardError(0.02) | MaxUpdates(1000)`. Then values close to 0 might
32+ be marked as "converged" rather quickly because they fulfill the first
33+ criterion, say after 20 iterations, despite being poor estimates. Because other
34+ indices take much longer to have low standard error and the criterion is a
35+ global check, the "converged" ones keep being updated and end up being good
36+ estimates. In this case, this has been beneficial, but one might not wish for
37+ converged values to be updated, if one is sure that the criterion is adequate
38+ for individual values.
39+
40+ [Semi-value methods][pydvl.value.semivalues] include a parameter
41+ `skip_converged` that allows to skip the computation of values that have
42+ converged. The way to avoid doing this too early is to use a more stringent
43+ check, e.g. `AbsoluteStandardError(1e-3) | MaxUpdates(1000)`. With
44+ `skip_converged=True` this check can still take less time than the first one,
45+ despite requiring more iterations for some indices.
46+
47+
48+ ## Choosing a stopping criterion
49+
50+ The choice of a stopping criterion greatly depends on the algorithm and the
51+ context. A safe bet is to combine a [MaxUpdates][pydvl.value.stopping.MaxUpdates]
52+ or a [MaxTime][pydvl.value.stopping.MaxTime] with a
53+ [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] or an
54+ [AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError]. The former
55+ will ensure that the computation does not run for too long, while the latter
56+ will try to achieve results that are stable enough. Note however that if the
57+ threshold is too strict, one will always end up running until a maximum number
58+ of iterations or time. Also keep in mind that different values converge at
59+ different times, so you might want to use tight thresholds and `skip_converged`
60+ as described above for semi-values.
61+
62+
63+ ??? Example
64+ ```python
65+ from pydvl.value import AbsoluteStandardError, MaxUpdates, compute_banzhaf_semivalues
66+
67+ utility = ... # some utility object
68+ criterion = AbsoluteStandardError(threshold=1e-3, burn_in=32) | MaxUpdates(1000)
69+ values = compute_banzhaf_semivalues(
70+ utility,
71+ criterion,
72+ skip_converged=True, # skip values that have converged (CAREFUL!)
73+ )
74+ ```
75+ This will compute the Banzhaf semivalues for `utility` until either the
76+ absolute standard error is below `1e-3` or `1000` updates have been
77+ performed. The `burn_in` parameter is used to discard the first `32` updates
78+ from the computation of the standard error. The `skip_converged` parameter
79+ is used to avoid computing more marginals for indices that have converged,
80+ which is useful if
81+ [AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError] is met
82+ before [MaxUpdates][pydvl.value.stopping.MaxUpdates] for some indices.
83+
84+ !!! Warning
85+ Be careful not to reuse the same stopping criterion for different
86+ computations. The object has state and will not be reset between calls to
87+ value computation methods. If you need to reuse the same criterion, you
88+ should create a new instance.
89+
90+
91+ ## Creating stopping criteria
1392
1493The easiest way is to declare a function implementing the interface
1594[StoppingCriterionCallable][pydvl.value.stopping.StoppingCriterionCallable] and
1897that can be composed with other stopping criteria.
1998
2099Alternatively, and in particular if reporting of completion is required, one can
21- inherit from this class and implement the abstract methods
22- [_check][pydvl.value.stopping.StoppingCriterion._check] and
100+ inherit from this class and implement the abstract methods `_check` and
23101[completion][pydvl.value.stopping.StoppingCriterion.completion].
24102
25- # Composing stopping criteria
103+ ## Combining stopping criteria
26104
27105Objects of type [StoppingCriterion][pydvl.value.stopping.StoppingCriterion] can
28- be composed with the binary operators `&` (*and*), and `|` (*or*), following the
106+ be combined with the binary operators `&` (*and*), and `|` (*or*), following the
29107truth tables of [Status][pydvl.utils.status.Status]. The unary operator `~`
30108(*not*) is also supported. See
31109[StoppingCriterion][pydvl.value.stopping.StoppingCriterion] for details on how
32110these operations affect the behavior of the stopping criteria.
33111
112+
34113## References
35114
36115[^1]: <a name="ghorbani_data_2019"></a>Ghorbani, A., Zou, J., 2019.
@@ -163,6 +242,15 @@ def converged(self) -> NDArray[np.bool_]:
163242
164243 @property
165244 def name (self ):
245+ log = logging .getLogger (__name__ )
246+ # This string for the benefit of deprecation searches:
247+ # remove_in="0.8.0"
248+ log .warning (
249+ "The `name` attribute of `StoppingCriterion` is deprecated and will be removed in 0.8.0. "
250+ )
251+ return getattr (self , "_name" , type (self ).__name__ )
252+
253+ def __str__ (self ):
166254 return type (self ).__name__
167255
168256 def __call__ (self , result : ValuationResult ) -> Status :
@@ -182,23 +270,23 @@ def __and__(self, other: "StoppingCriterion") -> "StoppingCriterion":
182270 fun = lambda result : self ._check (result ) & other ._check (result ),
183271 converged = lambda : self .converged & other .converged ,
184272 completion = lambda : min (self .completion (), other .completion ()),
185- name = f"Composite StoppingCriterion: { self . name } AND { other . name } " ,
273+ name = f"Composite StoppingCriterion: { str ( self ) } AND { str ( other ) } " ,
186274 )(modify_result = self .modify_result or other .modify_result )
187275
188276 def __or__ (self , other : "StoppingCriterion" ) -> "StoppingCriterion" :
189277 return make_criterion (
190278 fun = lambda result : self ._check (result ) | other ._check (result ),
191279 converged = lambda : self .converged | other .converged ,
192280 completion = lambda : max (self .completion (), other .completion ()),
193- name = f"Composite StoppingCriterion: { self . name } OR { other . name } " ,
281+ name = f"Composite StoppingCriterion: { str ( self ) } OR { str ( other ) } " ,
194282 )(modify_result = self .modify_result or other .modify_result )
195283
196284 def __invert__ (self ) -> "StoppingCriterion" :
197285 return make_criterion (
198286 fun = lambda result : ~ self ._check (result ),
199287 converged = lambda : ~ self .converged ,
200288 completion = lambda : 1 - self .completion (),
201- name = f"Composite StoppingCriterion: NOT { self . name } " ,
289+ name = f"Composite StoppingCriterion: NOT { str ( self ) } " ,
202290 )(modify_result = self .modify_result )
203291
204292
@@ -239,8 +327,7 @@ def converged(self) -> NDArray[np.bool_]:
239327 return super ().converged
240328 return converged ()
241329
242- @property
243- def name (self ):
330+ def __str__ (self ):
244331 return self ._name
245332
246333 def completion (self ) -> float :
@@ -254,13 +341,13 @@ def completion(self) -> float:
254341class AbsoluteStandardError (StoppingCriterion ):
255342 r"""Determine convergence based on the standard error of the values.
256343
257- If $s_i$ is the standard error for datum $i$ and $v_i$ its value , then this
258- criterion returns [Converged][pydvl.utils.status.Status] if
259- $s_i < \epsilon$ for all $i$ and a threshold value $\epsilon \gt 0$.
344+ If $s_i$ is the standard error for datum $i$, then this criterion returns
345+ [Converged][pydvl.utils.status.Status] if $s_i < \epsilon$ for all $i$ and a
346+ threshold value $\epsilon \gt 0$.
260347
261348 Args:
262349 threshold: A value is considered to have converged if the standard
263- error is below this value . A way of choosing it is to pick some
350+ error is below this threshold . A way of choosing it is to pick some
264351 percentage of the range of the values. For Shapley values this is
265352 the difference between the maximum and minimum of the utility
266353 function (to see this substitute the maximum and minimum values of
@@ -270,7 +357,7 @@ class AbsoluteStandardError(StoppingCriterion):
270357 burn_in: The number of iterations to ignore before checking for
271358 convergence. This is required because computations typically start
272359 with zero variance, as a result of using
273- [empty ()][pydvl.value.result.ValuationResult.empty ]. The default is
360+ [zeros ()][pydvl.value.result.ValuationResult.zeros ]. The default is
274361 set to an arbitrary minimum which is usually enough but may need to
275362 be increased.
276363 """
@@ -295,6 +382,9 @@ def _check(self, result: ValuationResult) -> Status:
295382 return Status .Converged
296383 return Status .Pending
297384
385+ def __str__ (self ):
386+ return f"AbsoluteStandardError(threshold={ self .threshold } , fraction={ self .fraction } , burn_in={ self .burn_in } )"
387+
298388
299389class StandardError (AbsoluteStandardError ):
300390 @deprecated (target = AbsoluteStandardError , deprecated_in = "0.6.0" , remove_in = "0.8.0" )
@@ -333,6 +423,9 @@ def completion(self) -> float:
333423 return min (1.0 , self ._count / self .n_checks )
334424 return 0.0
335425
426+ def __str__ (self ):
427+ return f"MaxChecks(n_checks={ self .n_checks } )"
428+
336429
337430class MaxUpdates (StoppingCriterion ):
338431 """Terminate if any number of value updates exceeds or equals the given
@@ -377,6 +470,9 @@ def completion(self) -> float:
377470 return self .last_max / self .n_updates
378471 return 0.0
379472
473+ def __str__ (self ):
474+ return f"MaxUpdates(n_updates={ self .n_updates } )"
475+
380476
381477class MinUpdates (StoppingCriterion ):
382478 """Terminate as soon as all value updates exceed or equal the given threshold.
@@ -414,6 +510,9 @@ def completion(self) -> float:
414510 return self .last_min / self .n_updates
415511 return 0.0
416512
513+ def __str__ (self ):
514+ return f"MinUpdates(n_updates={ self .n_updates } )"
515+
417516
418517class MaxTime (StoppingCriterion ):
419518 """Terminate if the computation time exceeds the given number of seconds.
@@ -447,6 +546,9 @@ def completion(self) -> float:
447546 return 0.0
448547 return (time () - self .start ) / self .max_seconds
449548
549+ def __str__ (self ):
550+ return f"MaxTime(seconds={ self .max_seconds } )"
551+
450552
451553class HistoryDeviation (StoppingCriterion ):
452554 r"""A simple check for relative distance to a previous step in the
@@ -527,3 +629,6 @@ def _check(self, r: ValuationResult) -> Status:
527629 if np .all (self ._converged ):
528630 return Status .Converged
529631 return Status .Pending
632+
633+ def __str__ (self ):
634+ return f"HistoryDeviation(n_steps={ self .n_steps } , rtol={ self .rtol } )"
0 commit comments