Skip to content

Commit 072424d

Browse files
committed
Use better criterion for tests
1 parent 84eb84e commit 072424d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/value/test_semivalues.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
compute_generic_semivalues,
2323
shapley_coefficient,
2424
)
25-
from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates
25+
from pydvl.value.stopping import HistoryDeviation, MaxUpdates
2626

2727
from . import check_values
2828
from .utils import timed
@@ -50,7 +50,7 @@ def test_shapley(
5050
parallel_config: ParallelConfig,
5151
):
5252
u, exact_values = analytic_shapley
53-
criterion = AbsoluteStandardError(1e-4) | MaxUpdates(1000)
53+
criterion = HistoryDeviation(50, 1e-3) | MaxUpdates(1000)
5454
values = compute_generic_semivalues(
5555
sampler(u.data.indices),
5656
u,
@@ -83,7 +83,7 @@ def test_shapley_batch_size(
8383
sampler(u.data.indices, seed=seed),
8484
u,
8585
coefficient,
86-
done=AbsoluteStandardError(1e-4) | MaxUpdates(1000),
86+
done=HistoryDeviation(50, 1e-3) | MaxUpdates(1000),
8787
skip_converged=True,
8888
n_jobs=n_jobs,
8989
batch_size=1,
@@ -95,7 +95,7 @@ def test_shapley_batch_size(
9595
sampler(u.data.indices, seed=seed),
9696
u,
9797
coefficient,
98-
done=AbsoluteStandardError(1e-4) | MaxUpdates(1000),
98+
done=HistoryDeviation(50, 1e-3) | MaxUpdates(1000),
9999
skip_converged=True,
100100
n_jobs=n_jobs,
101101
batch_size=batch_size,
@@ -128,7 +128,7 @@ def test_banzhaf(
128128
parallel_config: ParallelConfig,
129129
):
130130
u, exact_values = analytic_banzhaf
131-
criterion = AbsoluteStandardError(1e-4, burn_in=32) | MaxUpdates(1000)
131+
criterion = HistoryDeviation(50, 1e-3) | MaxUpdates(1000)
132132
values = compute_generic_semivalues(
133133
sampler(u.data.indices),
134134
u,

0 commit comments

Comments
 (0)