2222 compute_generic_semivalues ,
2323 shapley_coefficient ,
2424)
25- from pydvl .value .stopping import AbsoluteStandardError , MaxUpdates
25+ from pydvl .value .stopping import HistoryDeviation , MaxUpdates
2626
2727from . import check_values
2828from .utils import timed
@@ -50,7 +50,7 @@ def test_shapley(
5050 parallel_config : ParallelConfig ,
5151):
5252 u , exact_values = analytic_shapley
53- criterion = AbsoluteStandardError ( 1e-4 ) | MaxUpdates (1000 )
53+ criterion = HistoryDeviation ( 50 , 1e-3 ) | MaxUpdates (1000 )
5454 values = compute_generic_semivalues (
5555 sampler (u .data .indices ),
5656 u ,
@@ -83,7 +83,7 @@ def test_shapley_batch_size(
8383 sampler (u .data .indices , seed = seed ),
8484 u ,
8585 coefficient ,
86- done = AbsoluteStandardError ( 1e-4 ) | MaxUpdates (1000 ),
86+ done = HistoryDeviation ( 50 , 1e-3 ) | MaxUpdates (1000 ),
8787 skip_converged = True ,
8888 n_jobs = n_jobs ,
8989 batch_size = 1 ,
@@ -95,7 +95,7 @@ def test_shapley_batch_size(
9595 sampler (u .data .indices , seed = seed ),
9696 u ,
9797 coefficient ,
98- done = AbsoluteStandardError ( 1e-4 ) | MaxUpdates (1000 ),
98+ done = HistoryDeviation ( 50 , 1e-3 ) | MaxUpdates (1000 ),
9999 skip_converged = True ,
100100 n_jobs = n_jobs ,
101101 batch_size = batch_size ,
@@ -128,7 +128,7 @@ def test_banzhaf(
128128 parallel_config : ParallelConfig ,
129129):
130130 u , exact_values = analytic_banzhaf
131- criterion = AbsoluteStandardError ( 1e-4 , burn_in = 32 ) | MaxUpdates (1000 )
131+ criterion = HistoryDeviation ( 50 , 1e-3 ) | MaxUpdates (1000 )
132132 values = compute_generic_semivalues (
133133 sampler (u .data .indices ),
134134 u ,
0 commit comments