Skip to content

Commit 1f7d7f2

Browse files
committed
Rename rank stability and homogeneize interface
1 parent 90ad81b commit 1f7d7f2

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

docs/value/semi-values.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,14 @@ $\mathbf{S}_{\not{\ni} i}$ are the subsets not containing the index $i$.
115115

116116
The function implementing this method is
117117
[compute_msr_banzhaf_semivalues][pydvl.value.semivalues.compute_msr_banzhaf_semivalues].
118+
118119
```python
119-
from pydvl.value import compute_msr_banzhaf_semivalues, RankStability, Utility
120+
from pydvl.value import compute_msr_banzhaf_semivalues, RankCorrelation, Utility
120121

121122
utility = Utility(model, data)
122123
values = compute_msr_banzhaf_semivalues(
123-
u=utility, done=RankStability(rtol=0.001),
124-
)
124+
u=utility, done=RankCorrelation(rtol=0.001),
125+
)
125126
```
126127
For further details on how to use this method and a comparison of the sample
127128
efficiency, we suggest to take a look at the example notebook

notebooks/msr_banzhaf_digits.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@
694694
"cell_type": "markdown",
695695
"metadata": {},
696696
"source": [
697-
"Computing the values is the same, but we now use a better stopping criterion. Instead of fixing the number of utility evaluations with [MaxChecks](../../api/pydvl/value/stopping/#pydvl.value.stopping.MaxChecks), we use [RankStability](../../api/pydvl/value/stopping/#pydvl.value.stopping.RankStability) to stop when the change in Spearman correlation between the ranking of two successive iterations is below a threshold. "
697+
"Computing the values is the same, but we now use a better stopping criterion. Instead of fixing the number of utility evaluations with [MaxChecks](../../api/pydvl/value/stopping/#pydvl.value.stopping.MaxChecks), we use [RankCorrelation](../../api/pydvl/value/stopping/#pydvl.value.stopping.RankCorrelation) to stop when the change in Spearman correlation between the ranking of two successive iterations is below a threshold. "
698698
]
699699
},
700700
{
@@ -715,7 +715,7 @@
715715
"source": [
716716
"values = compute_msr_banzhaf_semivalues(\n",
717717
" utility,\n",
718-
" done=RankStability(0.0001),\n",
718+
" done=RankCorrelation(rtol=0.0001, burn_in=10),\n",
719719
" n_jobs=n_jobs,\n",
720720
" progress=True,\n",
721721
")\n",

src/pydvl/value/semivalues.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
SampleT,
118118
StochasticSampler,
119119
)
120-
from pydvl.value.stopping import MaxUpdates, RankStability, StoppingCriterion
120+
from pydvl.value.stopping import MaxUpdates, RankCorrelation, StoppingCriterion
121121

122122
__all__ = [
123123
"compute_banzhaf_semivalues",
@@ -635,7 +635,7 @@ def compute_banzhaf_semivalues(
635635
def compute_msr_banzhaf_semivalues(
636636
u: Utility,
637637
*,
638-
done: StoppingCriterion = RankStability(0.01),
638+
done: StoppingCriterion = RankCorrelation(0.01),
639639
sampler_t: Type[StochasticSampler] = MSRSampler,
640640
batch_size: int = 1,
641641
n_jobs: int = 1,

src/pydvl/value/stopping.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
"MinUpdates",
144144
"MaxTime",
145145
"HistoryDeviation",
146-
"RankStability",
146+
"RankCorrelation",
147147
]
148148

149149
logger = logging.getLogger(__name__)
@@ -630,7 +630,7 @@ def __str__(self):
630630
return f"HistoryDeviation(n_steps={self.n_steps}, rtol={self.rtol})"
631631

632632

633-
class RankStability(StoppingCriterion):
633+
class RankCorrelation(StoppingCriterion):
634634
r"""A check for stability of Spearman correlation between checks.
635635
636636
When the change in rank correlation between two successive iterations is
@@ -645,23 +645,31 @@ class RankStability(StoppingCriterion):
645645
646646
Args:
647647
rtol: Relative tolerance for convergence ($\epsilon$ in the formula)
648+
modify_result: If `True`, the status of the input
649+
[ValuationResult][pydvl.value.result.ValuationResult] is modified in
650+
place after the call.
651+
burn_in: The minimum number of iterations before checking for
652+
convergence. This is required because the first correlation is
653+
meaningless.
654+
655+
!!! tip "Added in 0.9.0"
648656
"""
649657

650658
def __init__(
651659
self,
652660
rtol: float,
661+
burn_in: int,
653662
modify_result: bool = True,
654-
min_iterations: int = 10,
655663
):
656664
super().__init__(modify_result=modify_result)
657665
if rtol <= 0 or rtol >= 1:
658666
raise ValueError("rtol must be in (0, 1)")
659667
self.rtol = rtol
660-
self._memory = None # type: ignore
668+
self.burn_in = burn_in
669+
self._memory: NDArray[np.float_] | None = None
661670
self._corr = 0.0
662671
self._completion = 0.0
663672
self._iterations = 0
664-
self.min_iterations = min_iterations
665673

666674
def _check(self, r: ValuationResult) -> Status:
667675
self._iterations += 1
@@ -675,11 +683,11 @@ def _check(self, r: ValuationResult) -> Status:
675683
self._update_completion(corr)
676684
if (
677685
np.isclose(corr, self._corr, rtol=self.rtol)
678-
and self._iterations > self.min_iterations
686+
and self._iterations > self.burn_in
679687
):
680688
self._converged = np.full(len(r), True)
681689
logger.debug(
682-
f"RankStability has converged with {corr=} in iteration {self._iterations}"
690+
f"RankCorrelation has converged with {corr=} in iteration {self._iterations}"
683691
)
684692
return Status.Converged
685693
self._corr = np.nan_to_num(corr, nan=0.0)
@@ -702,4 +710,4 @@ def reset(self):
702710
self._corr = 0.0
703711

704712
def __str__(self):
705-
return f"RankStability(rtol={self.rtol})"
713+
return f"RankCorrelation(rtol={self.rtol})"

tests/value/test_stopping.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
MaxTime,
1414
MaxUpdates,
1515
MinUpdates,
16-
RankStability,
16+
RankCorrelation,
1717
StoppingCriterion,
1818
make_criterion,
1919
)
@@ -199,12 +199,12 @@ def test_max_checks():
199199
assert done(v)
200200

201201

202-
def test_rank_stability():
203-
"""Test the RankStability stopping criterion."""
202+
def test_rank_correlation():
203+
"""Test the RankCorrelation stopping criterion."""
204204
v = ValuationResult.zeros(indices=range(5))
205205
arr = np.arange(5)
206206

207-
done = RankStability(rtol=0.1)
207+
done = RankCorrelation(rtol=0.1)
208208
for i in range(20):
209209
arr = np.roll(arr, 1)
210210
for j in range(5):
@@ -213,14 +213,14 @@ def test_rank_stability():
213213
assert not done(v)
214214
assert done(v)
215215

216-
done = RankStability(rtol=0.1, min_iterations=3)
216+
done = RankCorrelation(rtol=0.1, burn_in=3)
217217
v = ValuationResult.from_random(size=5)
218218
assert not done(v)
219219
assert not done(v)
220220
assert not done(v)
221221
assert done(v)
222222

223-
done = RankStability(rtol=0.1, min_iterations=2)
223+
done = RankCorrelation(rtol=0.1, burn_in=2)
224224
v = ValuationResult.from_random(size=5)
225225
assert not done(v)
226226
assert not done(v)

0 commit comments

Comments
 (0)