Skip to content

Commit 8f25b53

Browse files
committed
Require explicit stopping criteria (impossible to provide good defaults)
1 parent 1f7d7f2 commit 8f25b53

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

src/pydvl/value/semivalues.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,7 @@
101101
from deprecate import deprecated
102102
from tqdm import tqdm
103103

104-
from pydvl.parallel import (
105-
ParallelBackend,
106-
_maybe_init_parallel_backend,
107-
init_parallel_backend,
108-
)
104+
from pydvl.parallel import ParallelBackend, _maybe_init_parallel_backend
109105
from pydvl.parallel.config import ParallelConfig
110106
from pydvl.utils import Utility
111107
from pydvl.utils.types import IndexT, Seed
@@ -117,7 +113,7 @@
117113
SampleT,
118114
StochasticSampler,
119115
)
120-
from pydvl.value.stopping import MaxUpdates, RankCorrelation, StoppingCriterion
116+
from pydvl.value.stopping import StoppingCriterion
121117

122118
__all__ = [
123119
"compute_banzhaf_semivalues",
@@ -497,7 +493,7 @@ def beta_coefficient_w(n: int, k: int) -> float:
497493
def compute_shapley_semivalues(
498494
u: Utility,
499495
*,
500-
done: StoppingCriterion = MaxUpdates(100),
496+
done: StoppingCriterion,
501497
sampler_t: Type[StochasticSampler] = PermutationSampler,
502498
batch_size: int = 1,
503499
n_jobs: int = 1,
@@ -567,7 +563,7 @@ def compute_shapley_semivalues(
567563
def compute_banzhaf_semivalues(
568564
u: Utility,
569565
*,
570-
done: StoppingCriterion = MaxUpdates(100),
566+
done: StoppingCriterion,
571567
sampler_t: Type[StochasticSampler] = PermutationSampler,
572568
batch_size: int = 1,
573569
n_jobs: int = 1,
@@ -635,7 +631,7 @@ def compute_banzhaf_semivalues(
635631
def compute_msr_banzhaf_semivalues(
636632
u: Utility,
637633
*,
638-
done: StoppingCriterion = RankCorrelation(0.01),
634+
done: StoppingCriterion,
639635
sampler_t: Type[StochasticSampler] = MSRSampler,
640636
batch_size: int = 1,
641637
n_jobs: int = 1,
@@ -704,7 +700,7 @@ def compute_beta_shapley_semivalues(
704700
*,
705701
alpha: float = 1,
706702
beta: float = 1,
707-
done: StoppingCriterion = MaxUpdates(100),
703+
done: StoppingCriterion,
708704
sampler_t: Type[StochasticSampler] = PermutationSampler,
709705
batch_size: int = 1,
710706
n_jobs: int = 1,
@@ -782,7 +778,7 @@ class SemiValueMode(str, Enum):
782778
def compute_semivalues(
783779
u: Utility,
784780
*,
785-
done: StoppingCriterion = MaxUpdates(100),
781+
done: StoppingCriterion,
786782
mode: SemiValueMode = SemiValueMode.Shapley,
787783
sampler_t: Type[StochasticSampler] = PermutationSampler,
788784
batch_size: int = 1,

src/pydvl/value/shapley/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
from pydvl.value.shapley.owen import OwenAlgorithm, owen_sampling_shapley
1717
from pydvl.value.shapley.truncated import NoTruncation
1818
from pydvl.value.shapley.types import ShapleyMode
19-
from pydvl.value.stopping import MaxUpdates, StoppingCriterion
19+
from pydvl.value.stopping import StoppingCriterion
2020

2121
__all__ = ["compute_shapley_values"]
2222

2323

2424
def compute_shapley_values(
2525
u: Utility,
2626
*,
27-
done: StoppingCriterion = MaxUpdates(100),
27+
done: StoppingCriterion,
2828
mode: ShapleyMode = ShapleyMode.TruncatedMontecarlo,
2929
n_jobs: int = 1,
3030
seed: Optional[Seed] = None,

0 commit comments

Comments
 (0)