@@ -33,6 +33,7 @@ to the top of your Python file.
 - `numpy`
 - `scipy`
 - `pandas`
+- `joblib`

 Optional dependencies are

@@ -51,14 +52,14 @@ Shapley attribution of the out-of-sample $R^2$ on your data by executing
 attrs = ls_spa(X_train, X_test, y_train, y_test).attribution
 ```

-`attrs` will be a JAX vector containing the Shapley values of your features.
+`attrs` will be a NumPy array containing the Shapley values of your features.
 The `ls_spa` function computes Shapley values for the given data using
 the LS-SPA method described in the companion paper. It takes arguments:

-- `X_train`: Training feature matrix.
-- `X_test`: Testing feature matrix.
-- `y_train`: Training response vector.
-- `y_test`: Testing response vector.
+- `X_train`: Training feature matrix (NumPy array or pandas DataFrame).
+- `X_test`: Testing feature matrix (NumPy array or pandas DataFrame).
+- `y_train`: Training response vector (NumPy array or pandas Series).
+- `y_test`: Testing response vector (NumPy array or pandas Series).

 ## Hello world

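For intuition about the quantity this hunk documents (a Shapley attribution of the out-of-sample $R^2$), here is a brute-force sketch. It is not the LS-SPA method and not the package's implementation. It assumes $R^2$ is computed without an intercept against a zero-prediction baseline, and that `reg` enters as an unscaled ridge penalty; `out_of_sample_r2` and `exact_shapley_r2` are hypothetical helper names. It enumerates all p! feature orderings, so it is only practical for a handful of features.

```python
import itertools
import math

import numpy as np


def out_of_sample_r2(X_train, X_test, y_train, y_test, cols, reg=0.0):
    """Out-of-sample R^2 of a least-squares fit on the feature subset `cols`.

    Assumption: no intercept, and the empty model predicts zero (R^2 = 0).
    """
    if not cols:
        return 0.0
    A, B = X_train[:, cols], X_test[:, cols]
    # Ridge-regularized normal equations; reg=0.0 is ordinary least squares.
    theta = np.linalg.solve(A.T @ A + reg * np.eye(len(cols)), A.T @ y_train)
    resid = y_test - B @ theta
    return 1.0 - float(resid @ resid) / float(y_test @ y_test)


def exact_shapley_r2(X_train, X_test, y_train, y_test, reg=0.0):
    """Average each feature's marginal R^2 gain over every feature ordering."""
    p = X_train.shape[1]
    attrs = np.zeros(p)
    for perm in itertools.permutations(range(p)):
        prev, chosen = 0.0, []
        for j in perm:
            chosen.append(j)
            cur = out_of_sample_r2(X_train, X_test, y_train, y_test, chosen, reg)
            attrs[j] += cur - prev  # marginal contribution of feature j
            prev = cur
    return attrs / math.factorial(p)


# Small synthetic example (hypothetical data, three features).
rng = np.random.default_rng(0)
X = rng.standard_normal((60, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(60)
attrs = exact_shapley_r2(X[:40], X[40:], y[:40], y[40:])
```

By the efficiency property of Shapley values, the entries of `attrs` sum to the out-of-sample $R^2$ of the model fit on all features, which makes a quick sanity check on any attribution vector.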
@@ -104,28 +105,32 @@ on the same data.

 `ls_spa` takes the optional arguments:

-- `reg`: Regularization parameter (Default `0`).
-- `method`: Permutation sampling method. Options include `'random'`,
-  `'permutohedron'`, `'argsort'`, and `'exact'`. If `None`, `'argsort'` is used
-  if the number of features is greater than 10; otherwise, `'exact'` is used.
-- `batch_size`: Number of permutations in each batch (Default `2**7`).
-- `num_batches`: Maximum number of batches (Default `2**7`).
-- `tolerance`: Convergence tolerance for the Shapley values (Default `1e-2`).
+- `reg`: Ridge regularization parameter (Default `0.0`).
+- `max_samples`: Maximum number of feature permutations to sample (Default `8192`).
+- `batch_size`: Number of permutations to process per batch (Default `256`).
+- `tolerance`: Stopping criterion for estimation error (Default `0.01`).
 - `seed`: Seed for random number generation (Default `42`).
-- `return_history`: Flag to determine whether to return the history of error estimates and attributions for each feature chain (Default `False`).
+- `perms`: Permutation sampling method (Default `None`). Options include:
+  - `None`: Auto-select `"exact"` for p < 9 features, otherwise `"random"`
+  - `"exact"`: Enumerate all permutations (only feasible for p < 9)
+  - `"random"`: Uniformly random permutations
+  - `"argsort"`: Quasi-Monte Carlo permutations using argsort
+  - `"permutohedron"`: Quasi-Monte Carlo permutations from permutohedron lattice
+  - Custom array or tuple of permutations
+- `antithetical`: Use antithetical (paired) sampling for variance reduction (Default `True`).
+- `return_attribution_history`: Return convergence history of attributions (Default `False`).
+- `n_jobs`: Number of parallel jobs; use `-1` for all CPU cores (Default `1`).

 `ls_spa` returns a `ShapleyResults` object. The `ShapleyResults` object
 has the fields:

 - `attribution`: Array of Shapley values for each feature.
-- `attribution_history`: Array of Shapley values for each iteration.
-  `None` if `return_history=False` in `ls_spa` call.
-- `theta`: Array of regression coefficients.
-- `overall_error`: Mean absolute error of the Shapley values.
-- `error_history`: Array of mean absolute errors for each iteration.
-  `None` if `return_history=False` in `ls_spa` call.
-- `attribution_errors`: Array of absolute errors for each feature.
-- `r_squared`: Out-of-sample R-squared statistic of the regression.
+- `theta`: Array of regression coefficients with all features.
+- `r_squared`: Out-of-sample R² with all features.
+- `overall_error`: Estimated error (95th percentile L2 norm) in Shapley attribution vector.
+- `attribution_errors`: Array of estimated errors for each feature's attribution.
+- `error_history`: Array of error estimates after each batch. `None` if using exact computation.
+- `attribution_history`: Array of attribution estimates over time. `None` if `return_attribution_history=False`.

 ## Citing

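The `perms` and `antithetical` options described in this hunk can be illustrated with a small standalone sketch. It is not the package's code: `choose_method` and `sample_permutations` are hypothetical names, and pairing each random permutation with its reverse is an assumption about what "antithetical (paired)" means here. Only the `"exact"` and `"random"` options are sketched; the quasi-Monte Carlo options and custom permutation arrays are left out.

```python
import itertools
import random


def choose_method(p, perms=None):
    """Mirror the documented default: "exact" for p < 9, otherwise "random"."""
    if perms is None:
        return "exact" if p < 9 else "random"
    return perms


def sample_permutations(p, max_samples=8192, perms=None, antithetical=True, seed=42):
    """Return a list of feature orderings (each a list of column indices)."""
    method = choose_method(p, perms)
    if method == "exact":
        # All p! orderings; only feasible for small p.
        return [list(q) for q in itertools.permutations(range(p))]
    if method != "random":
        raise NotImplementedError(f"method {method!r} is not sketched here")
    rng = random.Random(seed)
    out = []
    while len(out) < max_samples:
        perm = list(range(p))
        rng.shuffle(perm)
        out.append(perm)
        # Assumed pairing: follow each permutation with its reverse.
        if antithetical and len(out) < max_samples:
            out.append(perm[::-1])
    return out


orderings = sample_permutations(10, max_samples=4)
```

With the assumed pairing, a feature that appears early in one ordering appears late in its partner, which is the intuition behind using paired samples to reduce the variance of the averaged marginal contributions.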