Skip to content

Commit f682de8

Browse files
committed
Fixed one error; added joblib dependency for parallel lift computation; all tests pass
1 parent ae0f550 commit f682de8

File tree

5 files changed

+114
-45
lines changed

5 files changed

+114
-45
lines changed

ls_spa/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from .ls_spa import (
44
ShapleyResults,
5-
SizeIncompatible,
65
SizeIncompatibleError,
76
error_estimates,
87
ls_spa,

ls_spa/ls_spa.py

Lines changed: 99 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import pandas as pd
2929
import scipy as sp
30+
from joblib import Parallel, delayed
3031
from numpy import random
3132

3233
from ls_spa.qmc import argsort_samples, permutohedron_samples
@@ -226,6 +227,53 @@ def process_perms(
226227
return perms
227228

228229

230+
def _compute_lift(
231+
perm: np.ndarray,
232+
X_train_tilde: np.ndarray,
233+
X_test_tilde: np.ndarray,
234+
y_train_tilde: np.ndarray,
235+
y_test_tilde: np.ndarray,
236+
y_test_norm_sq: float,
237+
antithetical: bool,
238+
) -> np.ndarray:
239+
"""Compute the lift for a single permutation.
240+
241+
Args:
242+
perm: The permutation to use.
243+
X_train_tilde: The reduced training data.
244+
X_test_tilde: The reduced test data.
245+
y_train_tilde: The reduced training labels.
246+
y_test_tilde: The reduced test labels.
247+
y_test_norm_sq: The squared norm of the test labels.
248+
antithetical: Whether to use antithetical sampling.
249+
250+
Returns:
251+
The lift vector.
252+
"""
253+
perm_np = np.array(perm)
254+
lift = square_shapley(
255+
X_train_tilde,
256+
X_test_tilde,
257+
y_train_tilde,
258+
y_test_tilde,
259+
y_test_norm_sq,
260+
perm_np,
261+
)
262+
if antithetical:
263+
lift = (
264+
lift
265+
+ square_shapley(
266+
X_train_tilde,
267+
X_test_tilde,
268+
y_train_tilde,
269+
y_test_tilde,
270+
y_test_norm_sq,
271+
perm_np[::-1],
272+
)
273+
) / 2
274+
return lift
275+
276+
229277
def ls_spa(
230278
X_train: np.ndarray | pd.DataFrame,
231279
X_test: np.ndarray | pd.DataFrame,
@@ -239,6 +287,7 @@ def ls_spa(
239287
perms: PERM_TYPE | None = None,
240288
antithetical: bool = True,
241289
return_attribution_history: bool = False,
290+
n_jobs: int = 1,
242291
) -> ShapleyResults:
243292
"""Estimates the Shapley attribution for a least-squares problem.
244293
@@ -256,6 +305,8 @@ def ls_spa(
256305
generated randomly.
257306
antithetical (bool): Whether to use antithetical sampling.
258307
return_attribution_history (bool): Whether to return the attribution history.
308+
n_jobs (int): The number of parallel jobs to use. Use -1 to use all available
309+
CPU cores. Default is 1 (sequential processing).
259310
260311
Returns:
261312
A ShapleyResults object containing the Shapley attribution, the
@@ -277,7 +328,10 @@ def ls_spa(
277328
antithetical = False
278329

279330
perms = process_perms(p, rng, max_samples, perms)
280-
max_samples = len(perms)
331+
332+
# Convert to list for batching (handles iterators like it.permutations)
333+
perms_list = list(perms)
334+
max_samples = len(perms_list)
281335

282336
# Compute the reduction
283337
y_test_norm_sq = np.linalg.norm(y_test) ** 2
@@ -289,72 +343,73 @@ def ls_spa(
289343
reg,
290344
)
291345

292-
# Iterate over the permutations to compute lifts
346+
# Initialize accumulators for the Shapley attribution
293347
shapley_values = np.zeros(p)
294348
attribution_cov = np.zeros((p, p))
295349
attribution_errors = np.full(p, 0.0)
296350
overall_error = 0.0
297351
error_history = np.zeros(0)
298352
attribution_history = np.zeros((0, p)) if return_attribution_history else None
299353

354+
# Iterate over permutations in batches
300355
i = 0
301-
for perm in perms:
302-
i += 1
303-
do_mini_batch = True
304-
305-
# Compute the lift
306-
perm_np = np.array(perm)
307-
lift = square_shapley(
308-
X_train_tilde,
309-
X_test_tilde,
310-
y_train_tilde,
311-
y_test_tilde,
312-
y_test_norm_sq,
313-
perm_np,
314-
)
315-
if antithetical:
316-
lift = (
317-
lift
318-
+ square_shapley(
356+
for batch_start in range(0, max_samples, batch_size):
357+
batch_end = min(batch_start + batch_size, max_samples)
358+
batch_perms = perms_list[batch_start:batch_end]
359+
360+
# Compute lifts for the batch (parallel or sequential)
361+
if n_jobs == 1:
362+
lifts = [
363+
_compute_lift(
364+
perm,
319365
X_train_tilde,
320366
X_test_tilde,
321367
y_train_tilde,
322368
y_test_tilde,
323369
y_test_norm_sq,
324-
perm_np[::-1],
370+
antithetical,
325371
)
326-
) / 2
327-
328-
# Update the mean and biased sample covariance
329-
attribution_cov = merge_sample_cov(
330-
shapley_values,
331-
lift,
332-
attribution_cov,
333-
np.zeros((p, p)),
334-
i - 1,
335-
1,
336-
)
337-
shapley_values = merge_sample_mean(shapley_values, lift, i - 1, 1)
338-
if return_attribution_history:
339-
attribution_history = np.vstack((attribution_history, shapley_values))
372+
for perm in batch_perms
373+
]
374+
else:
375+
lifts = Parallel(n_jobs=n_jobs)(
376+
delayed(_compute_lift)(
377+
perm,
378+
X_train_tilde,
379+
X_test_tilde,
380+
y_train_tilde,
381+
y_test_tilde,
382+
y_test_norm_sq,
383+
antithetical,
384+
)
385+
for perm in batch_perms
386+
)
340387

341-
# Update the errors
342-
if (i % batch_size == 0 or i == max_samples - 1) and p >= MAX_FEAS_EXACT_FEATS:
388+
# Aggregate lifts sequentially (updates running mean and covariance)
389+
for lift in lifts:
390+
i += 1
391+
attribution_cov = merge_sample_cov(
392+
shapley_values,
393+
lift,
394+
attribution_cov,
395+
np.zeros((p, p)),
396+
i - 1,
397+
1,
398+
)
399+
shapley_values = merge_sample_mean(shapley_values, lift, i - 1, 1)
400+
if return_attribution_history:
401+
attribution_history = np.vstack((attribution_history, shapley_values))
402+
403+
# Update errors after each batch
404+
if p >= MAX_FEAS_EXACT_FEATS and i > 1:
343405
unbiased_cov = attribution_cov * i / (i - 1)
344406
attribution_errors, overall_error = error_estimates(rng, unbiased_cov / i)
345407
error_history = np.append(error_history, overall_error)
346-
do_mini_batch = False
347408

348409
# Check the stopping criterion
349410
if overall_error < tolerance:
350411
break
351412

352-
# Last mini-batch
353-
if p >= MAX_FEAS_EXACT_FEATS and do_mini_batch:
354-
unbiased_cov = attribution_cov * i / (i - 1)
355-
attribution_errors, overall_error = error_estimates(rng, unbiased_cov / i)
356-
error_history = np.append(error_history, overall_error)
357-
358413
# Compute auxiliary information
359414
theta = np.linalg.lstsq(X_train_tilde, y_train_tilde, rcond=None)[0]
360415
r_squared = (

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies = [
1313
"numpy>=2.3.4,<3",
1414
"scipy>=1.16.2,<2",
1515
"pandas>=2.3.3,<3",
16+
"joblib>=1.4.0,<2",
1617
]
1718

1819
[dependency-groups]

tests/test_ls_spa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def test_correctness_easy(self) -> None:
186186
self.y_test_easy,
187187
max_samples=256 * 256,
188188
batch_size=256,
189+
n_jobs=-1,
189190
)
190191
np.testing.assert_almost_equal(proposal, easy_results.attribution)
191192

uv.lock

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)