Skip to content

Commit 257de62

Browse files
committed
Fix construction of samplers in some tests
1 parent 1a8b9cb commit 257de62

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tests/valuation/methods/test_montecarlo_shapley.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def shapley_methods(fudge_factor: int):
5959
AntitheticSampler,
6060
{"seed": lambda seed: seed},
6161
ShapleyValuation,
62-
{"is_done": (MinUpdates, {"n_updates": fudge_factor // 2})},
62+
{"is_done": (MinUpdates, {"n_updates": fudge_factor})},
6363
),
6464
(
6565
MSRSampler,
@@ -291,14 +291,15 @@ def test_hoeffding_bound_montecarlo(
291291
@pytest.mark.parametrize(
292292
"sampler_cls, sampler_kwargs, valuation_cls, valuation_kwargs", shapley_methods(500)
293293
)
294-
def test_linear_montecarlo(
294+
def test_linear_montecarlo_with_outlier(
295295
linear_dataset,
296296
linear_shapley,
297297
n_jobs: int,
298298
sampler_cls: Type,
299299
sampler_kwargs: dict[str, Any],
300300
valuation_cls: Type,
301301
valuation_kwargs: dict[str, Any],
302+
seed: int,
302303
):
303304
"""Tests whether valuation methods are able to detect an obvious outlier.
304305
@@ -318,11 +319,13 @@ def test_linear_montecarlo(
318319
# train.data().y[outlier_idx] -= 100
319320

320321
if sampler_cls is not None:
321-
valuation_kwargs["sampler"] = recursive_make(sampler_cls, sampler_kwargs)
322+
valuation_kwargs["sampler"] = recursive_make(
323+
sampler_cls, sampler_kwargs, seed=seed, lower_bound=0, upper_bound=None
324+
)
322325

323326
valuation_kwargs["utility"] = utility
324327
valuation_kwargs["progress"] = False
325-
valuation = recursive_make(valuation_cls, valuation_kwargs)
328+
valuation = recursive_make(valuation_cls, valuation_kwargs, seed=seed)
326329

327330
with parallel_config(n_jobs=n_jobs):
328331
valuation.fit(train)

tests/valuation/samplers/test_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def random_samplers(proper: bool = False):
111111
{
112112
"outer_sampling_strategy": (
113113
UniformOwenStrategy,
114-
{"n_samples_outer": lambda n=2: n},
114+
{"n_samples_outer": lambda n=200: n},
115115
),
116116
"index_iteration": NoIndexIteration,
117117
},
@@ -121,7 +121,7 @@ def random_samplers(proper: bool = False):
121121
{
122122
"outer_sampling_strategy": (
123123
UniformOwenStrategy,
124-
{"n_samples_outer": lambda n=2: n},
124+
{"n_samples_outer": lambda n=200: n},
125125
),
126126
"index_iteration": NoIndexIteration,
127127
},
@@ -250,7 +250,7 @@ def random_samplers(proper: bool = False):
250250
{
251251
"outer_sampling_strategy": (
252252
UniformOwenStrategy,
253-
{"n_samples_outer": lambda n=32: n, "seed": lambda seed: seed},
253+
{"n_samples_outer": lambda n=200: n, "seed": lambda seed: seed},
254254
),
255255
"index_iteration": FiniteSequentialIndexIteration,
256256
},

0 commit comments

Comments
 (0)