 import numpy as np
 import pandas as pd
 import scipy as sp
+from joblib import Parallel, delayed
 from numpy import random

 from ls_spa.qmc import argsort_samples, permutohedron_samples
@@ -226,6 +227,53 @@ def process_perms(
     return perms


+def _compute_lift(
+    perm: np.ndarray,
+    X_train_tilde: np.ndarray,
+    X_test_tilde: np.ndarray,
+    y_train_tilde: np.ndarray,
+    y_test_tilde: np.ndarray,
+    y_test_norm_sq: float,
+    antithetical: bool,
+) -> np.ndarray:
+    """Compute the lift for a single permutation.
+
+    Args:
+        perm: The permutation to use.
+        X_train_tilde: The reduced training data.
+        X_test_tilde: The reduced test data.
+        y_train_tilde: The reduced training labels.
+        y_test_tilde: The reduced test labels.
+        y_test_norm_sq: The squared norm of the test labels.
+        antithetical: Whether to use antithetical sampling.
+
+    Returns:
+        The lift vector.
+    """
+    perm_np = np.array(perm)
+    lift = square_shapley(
+        X_train_tilde,
+        X_test_tilde,
+        y_train_tilde,
+        y_test_tilde,
+        y_test_norm_sq,
+        perm_np,
+    )
+    if antithetical:
+        lift = (
+            lift
+            + square_shapley(
+                X_train_tilde,
+                X_test_tilde,
+                y_train_tilde,
+                y_test_tilde,
+                y_test_norm_sq,
+                perm_np[::-1],
+            )
+        ) / 2
+    return lift
+
+
 def ls_spa(
     X_train: np.ndarray | pd.DataFrame,
     X_test: np.ndarray | pd.DataFrame,
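Two notes on the _compute_lift helper added above. It is defined at module level rather than as a closure inside ls_spa, which keeps it straightforward for joblib's process-based backends to serialize and hand to worker processes. The antithetical branch averages the lift computed on a permutation with the lift computed on its reverse, a variance-reduction step for the running mean. A minimal, self-contained sketch of that averaging pattern follows; the estimate callable is a hypothetical stand-in, not part of ls_spa:

import numpy as np

def antithetical_average(estimate, perm: np.ndarray) -> np.ndarray:
    # Average one evaluation on the permutation and one on its reverse,
    # mirroring how _compute_lift combines the two square_shapley calls.
    return (estimate(perm) + estimate(perm[::-1])) / 2

# Toy usage: each sample fed to the running mean is the average of two
# evaluations of the same stand-in estimator.
perm = np.random.default_rng(0).permutation(5)
sample = antithetical_average(lambda q: q.astype(float), perm)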
@@ -239,6 +287,7 @@ def ls_spa(
     perms: PERM_TYPE | None = None,
     antithetical: bool = True,
     return_attribution_history: bool = False,
+    n_jobs: int = 1,
 ) -> ShapleyResults:
     """Estimates the Shapley attribution for a least-squares problem.

@@ -256,6 +305,8 @@ def ls_spa(
             generated randomly.
         antithetical (bool): Whether to use antithetical sampling.
         return_attribution_history (bool): Whether to return the attribution history.
+        n_jobs (int): The number of parallel jobs to use. Use -1 to use all available
+            CPU cores. Default is 1 (sequential processing).

     Returns:
         A ShapleyResults object containing the Shapley attribution, the
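A minimal usage sketch for the new argument, assuming the X_train/X_test/y_train/y_test call order shown above and a top-level ls_spa import; the import path and the synthetic data are illustrative assumptions, not taken from the repository:

import numpy as np
from ls_spa import ls_spa  # import path assumed for illustration

rng = np.random.default_rng(0)
X_train = rng.standard_normal((500, 10))
X_test = rng.standard_normal((100, 10))
beta = rng.standard_normal(10)
y_train = X_train @ beta + 0.1 * rng.standard_normal(500)
y_test = X_test @ beta + 0.1 * rng.standard_normal(100)

# n_jobs=-1 lets joblib use every available CPU core for the per-permutation
# lift computations; n_jobs=1 (the default) keeps the original sequential path.
results = ls_spa(X_train, X_test, y_train, y_test, n_jobs=-1)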
@@ -277,7 +328,10 @@ def ls_spa(
         antithetical = False

     perms = process_perms(p, rng, max_samples, perms)
-    max_samples = len(perms)
+
+    # Convert to list for batching (handles iterators like it.permutations)
+    perms_list = list(perms)
+    max_samples = len(perms_list)

     # Compute the reduction
     y_test_norm_sq = np.linalg.norm(y_test) ** 2
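The list() call matters because process_perms may return a lazy iterator (the comment above mentions it.permutations), and an iterator can be neither sliced into batches nor measured with len(). A quick illustration:

import itertools as it

perms = it.permutations(range(3))
# len(perms) and perms[0:2] would both raise TypeError on the raw iterator.
perms_list = list(perms)     # materialize once
batch = perms_list[0:2]      # slicing now works: [(0, 1, 2), (0, 2, 1)]
n_samples = len(perms_list)  # 6 permutations of 3 items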
@@ -289,72 +343,73 @@ def ls_spa(
         reg,
     )

-    # Iterate over the permutations to compute lifts
+    # Initialize accumulators for the Shapley attribution
     shapley_values = np.zeros(p)
     attribution_cov = np.zeros((p, p))
     attribution_errors = np.full(p, 0.0)
     overall_error = 0.0
     error_history = np.zeros(0)
     attribution_history = np.zeros((0, p)) if return_attribution_history else None

+    # Iterate over permutations in batches
     i = 0
-    for perm in perms:
-        i += 1
-        do_mini_batch = True
-
-        # Compute the lift
-        perm_np = np.array(perm)
-        lift = square_shapley(
-            X_train_tilde,
-            X_test_tilde,
-            y_train_tilde,
-            y_test_tilde,
-            y_test_norm_sq,
-            perm_np,
-        )
-        if antithetical:
-            lift = (
-                lift
-                + square_shapley(
+    for batch_start in range(0, max_samples, batch_size):
+        batch_end = min(batch_start + batch_size, max_samples)
+        batch_perms = perms_list[batch_start:batch_end]
+
+        # Compute lifts for the batch (parallel or sequential)
+        if n_jobs == 1:
+            lifts = [
+                _compute_lift(
+                    perm,
                     X_train_tilde,
                     X_test_tilde,
                     y_train_tilde,
                     y_test_tilde,
                     y_test_norm_sq,
-                    perm_np[::-1],
+                    antithetical,
                 )
-            ) / 2
-
-        # Update the mean and biased sample covariance
-        attribution_cov = merge_sample_cov(
-            shapley_values,
-            lift,
-            attribution_cov,
-            np.zeros((p, p)),
-            i - 1,
-            1,
-        )
-        shapley_values = merge_sample_mean(shapley_values, lift, i - 1, 1)
-        if return_attribution_history:
-            attribution_history = np.vstack((attribution_history, shapley_values))
+                for perm in batch_perms
+            ]
+        else:
+            lifts = Parallel(n_jobs=n_jobs)(
+                delayed(_compute_lift)(
+                    perm,
+                    X_train_tilde,
+                    X_test_tilde,
+                    y_train_tilde,
+                    y_test_tilde,
+                    y_test_norm_sq,
+                    antithetical,
+                )
+                for perm in batch_perms
+            )

-        # Update the errors
-        if (i % batch_size == 0 or i == max_samples - 1) and p >= MAX_FEAS_EXACT_FEATS:
+        # Aggregate lifts sequentially (updates running mean and covariance)
+        for lift in lifts:
+            i += 1
+            attribution_cov = merge_sample_cov(
+                shapley_values,
+                lift,
+                attribution_cov,
+                np.zeros((p, p)),
+                i - 1,
+                1,
+            )
+            shapley_values = merge_sample_mean(shapley_values, lift, i - 1, 1)
+            if return_attribution_history:
+                attribution_history = np.vstack((attribution_history, shapley_values))
+
+        # Update errors after each batch
+        if p >= MAX_FEAS_EXACT_FEATS and i > 1:
             unbiased_cov = attribution_cov * i / (i - 1)
             attribution_errors, overall_error = error_estimates(rng, unbiased_cov / i)
             error_history = np.append(error_history, overall_error)
-            do_mini_batch = False

         # Check the stopping criterion
         if overall_error < tolerance:
             break

-    # Last mini-batch
-    if p >= MAX_FEAS_EXACT_FEATS and do_mini_batch:
-        unbiased_cov = attribution_cov * i / (i - 1)
-        attribution_errors, overall_error = error_estimates(rng, unbiased_cov / i)
-        error_history = np.append(error_history, overall_error)
-
     # Compute auxiliary information
     theta = np.linalg.lstsq(X_train_tilde, y_train_tilde, rcond=None)[0]
     r_squared = (
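The rewritten loop is a batch-then-reduce pattern: each batch of permutations is mapped to lift vectors, in parallel when n_jobs is set to anything other than 1, and the lifts are then folded into the running mean and covariance one at a time, so the statistics and the stopping check behave as in the old one-permutation-at-a-time loop. A stripped-down sketch of the same pattern, with a toy work function standing in for _compute_lift:

from joblib import Parallel, delayed

def work(x: int) -> float:
    return float(x) ** 0.5  # hypothetical stand-in for _compute_lift

items = list(range(100))
batch_size, n_jobs = 16, -1
running_mean, count = 0.0, 0

for start in range(0, len(items), batch_size):
    batch = items[start:start + batch_size]
    # Map step: evaluate the whole batch, possibly across worker processes.
    results = Parallel(n_jobs=n_jobs)(delayed(work)(x) for x in batch)
    # Reduce step: update the running mean in the same order a purely
    # sequential loop would, so convergence checks are unaffected.
    for r in results:
        count += 1
        running_mean += (r - running_mean) / count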