|
101 | 101 | from deprecate import deprecated |
102 | 102 | from tqdm import tqdm |
103 | 103 |
|
104 | | -from pydvl.parallel import ( |
105 | | - ParallelBackend, |
106 | | - _maybe_init_parallel_backend, |
107 | | - init_parallel_backend, |
108 | | -) |
| 104 | +from pydvl.parallel import ParallelBackend, _maybe_init_parallel_backend |
109 | 105 | from pydvl.parallel.config import ParallelConfig |
110 | 106 | from pydvl.utils import Utility |
111 | 107 | from pydvl.utils.types import IndexT, Seed |
|
117 | 113 | SampleT, |
118 | 114 | StochasticSampler, |
119 | 115 | ) |
120 | | -from pydvl.value.stopping import MaxUpdates, RankCorrelation, StoppingCriterion |
| 116 | +from pydvl.value.stopping import StoppingCriterion |
121 | 117 |
|
122 | 118 | __all__ = [ |
123 | 119 | "compute_banzhaf_semivalues", |
@@ -497,7 +493,7 @@ def beta_coefficient_w(n: int, k: int) -> float: |
497 | 493 | def compute_shapley_semivalues( |
498 | 494 | u: Utility, |
499 | 495 | *, |
500 | | - done: StoppingCriterion = MaxUpdates(100), |
| 496 | + done: StoppingCriterion, |
501 | 497 | sampler_t: Type[StochasticSampler] = PermutationSampler, |
502 | 498 | batch_size: int = 1, |
503 | 499 | n_jobs: int = 1, |
@@ -567,7 +563,7 @@ def compute_shapley_semivalues( |
567 | 563 | def compute_banzhaf_semivalues( |
568 | 564 | u: Utility, |
569 | 565 | *, |
570 | | - done: StoppingCriterion = MaxUpdates(100), |
| 566 | + done: StoppingCriterion, |
571 | 567 | sampler_t: Type[StochasticSampler] = PermutationSampler, |
572 | 568 | batch_size: int = 1, |
573 | 569 | n_jobs: int = 1, |
@@ -635,7 +631,7 @@ def compute_banzhaf_semivalues( |
635 | 631 | def compute_msr_banzhaf_semivalues( |
636 | 632 | u: Utility, |
637 | 633 | *, |
638 | | - done: StoppingCriterion = RankCorrelation(0.01), |
| 634 | + done: StoppingCriterion, |
639 | 635 | sampler_t: Type[StochasticSampler] = MSRSampler, |
640 | 636 | batch_size: int = 1, |
641 | 637 | n_jobs: int = 1, |
@@ -704,7 +700,7 @@ def compute_beta_shapley_semivalues( |
704 | 700 | *, |
705 | 701 | alpha: float = 1, |
706 | 702 | beta: float = 1, |
707 | | - done: StoppingCriterion = MaxUpdates(100), |
| 703 | + done: StoppingCriterion, |
708 | 704 | sampler_t: Type[StochasticSampler] = PermutationSampler, |
709 | 705 | batch_size: int = 1, |
710 | 706 | n_jobs: int = 1, |
@@ -782,7 +778,7 @@ class SemiValueMode(str, Enum): |
782 | 778 | def compute_semivalues( |
783 | 779 | u: Utility, |
784 | 780 | *, |
785 | | - done: StoppingCriterion = MaxUpdates(100), |
| 781 | + done: StoppingCriterion, |
786 | 782 | mode: SemiValueMode = SemiValueMode.Shapley, |
787 | 783 | sampler_t: Type[StochasticSampler] = PermutationSampler, |
788 | 784 | batch_size: int = 1, |
|
0 commit comments