Skip to content

Commit 203c555

Browse files
committed
Iterations per job has different semantics for comb MCShap
1 parent 25295f3 commit 203c555

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

src/pydvl/shapley/montecarlo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def combinatorial_montecarlo_shapley(
322322
"""
323323
parallel_backend = init_parallel_backend(config)
324324
u_id = parallel_backend.put(u)
325-
iterations_per_job = max_iterations // n_jobs
326325

327326
def reducer(results_it: Iterable[MonteCarloResults]) -> MonteCarloResults:
328327
values = np.zeros(len(u.data))
@@ -335,13 +334,14 @@ def reducer(results_it: Iterable[MonteCarloResults]) -> MonteCarloResults:
335334
stderr += std
336335
return MonteCarloResults(values=values, stderr=stderr)
337336

337+
# FIXME? max_iterations has different semantics in permutation-based methods
338338
map_reduce_job: MapReduceJob["NDArray", MonteCarloResults] = MapReduceJob(
339339
map_func=_combinatorial_montecarlo_shapley,
340340
reduce_func=reducer,
341341
map_kwargs=dict(
342342
u=u_id,
343343
dist=dist,
344-
max_iterations=iterations_per_job,
344+
max_iterations=max_iterations,
345345
progress=progress,
346346
),
347347
chunkify_inputs=True,

tests/shapley/test_montecarlo.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,6 @@ def test_analytic_montecarlo_shapley(
5353
def test_hoeffding_bound_montecarlo(
5454
analytic_shapley, fun, n_jobs, delta, eps, tolerate
5555
):
56-
"""FIXME: This test passes but there are several unclear points.
57-
For example, map_reduce is called with num_jobs=num_runs. Is this correct?
58-
If I put num_jobs=jobs_per_run, map_reduce encounters errors since a utility
59-
is passed.
60-
Before coming back to this test, fix map_reduce interface."""
6156
u, exact_values = analytic_shapley
6257

6358
max_iterations = lower_bound_hoeffding(delta=delta, eps=eps, score_range=1)
@@ -93,8 +88,8 @@ def test_hoeffding_bound_montecarlo(
9388
12,
9489
combinatorial_montecarlo_shapley,
9590
"explained_variance",
96-
0.5,
97-
5000,
91+
0.2,
92+
2**11,
9893
),
9994
],
10095
)
@@ -116,7 +111,7 @@ def test_linear_montecarlo_shapley(
116111
)
117112
values, _ = fun(
118113
linear_utility,
119-
max_iterations=max_iterations,
114+
max_iterations=int(max_iterations),
120115
progress=False,
121116
n_jobs=n_jobs,
122117
)

0 commit comments

Comments
 (0)