Skip to content

Commit 1bc709c

Browse files
authored
Merge pull request #7 from cvxgrp/better_sampling
Better sampling
2 parents 2f16751 + a0ec4a4 commit 1bc709c

File tree

7 files changed

+592
-144
lines changed

7 files changed

+592
-144
lines changed

README.md

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ to the top of your Python file.
3333
- `numpy`
3434
- `scipy`
3535
- `pandas`
36+
- `joblib`
3637

3738
Optional dependencies are
3839

@@ -51,14 +52,14 @@ Shapley attribution of the out-of-sample $R^2$ on your data by executing
5152
attrs = ls_spa(X_train, X_test, y_train, y_test).attribution
5253
```
5354

54-
`attrs` will be a JAX vector containing the Shapley values of your features.
55+
`attrs` will be a NumPy array containing the Shapley values of your features.
5556
The `ls_spa` function computes Shapley values for the given data using
5657
the LS-SPA method described in the companion paper. It takes arguments:
5758

58-
- `X_train`: Training feature matrix.
59-
- `X_test`: Testing feature matrix.
60-
- `y_train`: Training response vector.
61-
- `y_test`: Testing response vector.
59+
- `X_train`: Training feature matrix (NumPy array or pandas DataFrame).
60+
- `X_test`: Testing feature matrix (NumPy array or pandas DataFrame).
61+
- `y_train`: Training response vector (NumPy array or pandas Series).
62+
- `y_test`: Testing response vector (NumPy array or pandas Series).
6263

6364
## Hello world
6465

@@ -104,28 +105,32 @@ on the same data.
104105

105106
`ls_spa` takes the optional arguments:
106107

107-
- `reg`: Regularization parameter (Default `0`).
108-
- `method`: Permutation sampling method. Options include `'random'`,
109-
`'permutohedron'`, `'argsort'`, and `'exact'`. If `None`, `'argsort'` is used
110-
if the number of features is greater than 10; otherwise, `'exact'` is used.
111-
- `batch_size`: Number of permutations in each batch (Default `2**7`).
112-
- `num_batches`: Maximum number of batches (Default `2**7`).
113-
- `tolerance`: Convergence tolerance for the Shapley values (Default `1e-2`).
108+
- `reg`: Ridge regularization parameter (Default `0.0`).
109+
- `max_samples`: Maximum number of feature permutations to sample (Default `8192`).
110+
- `batch_size`: Number of permutations to process per batch (Default `256`).
111+
- `tolerance`: Stopping criterion for estimation error (Default `0.01`).
114112
- `seed`: Seed for random number generation (Default `42`).
115-
- `return_history`: Flag to determine whether to return the history of error estimates and attributions for each feature chain (Default `False`).
113+
- `perms`: Permutation sampling method (Default `None`). Options include:
114+
- `None`: Auto-select `"exact"` for p < 9 features, otherwise `"random"`
115+
- `"exact"`: Enumerate all permutations (only feasible for p < 9)
116+
- `"random"`: Uniformly random permutations
117+
- `"argsort"`: Quasi-Monte Carlo permutations using argsort
118+
- `"permutohedron"`: Quasi-Monte Carlo permutations from permutohedron lattice
119+
- Custom array or tuple of permutations
120+
- `antithetical`: Use antithetical (paired) sampling for variance reduction (Default `True`).
121+
- `return_attribution_history`: Return convergence history of attributions (Default `False`).
122+
- `n_jobs`: Number of parallel jobs; use `-1` for all CPU cores (Default `1`).
116123

117124
`ls_spa` returns a `ShapleyResults` object. The `ShapleyResults` object
118125
has the fields:
119126

120127
- `attribution`: Array of Shapley values for each feature.
121-
- `attribution_history`: Array of Shapley values for each iteration.
122-
`None` if `return_history=False` in `ls_spa` call.
123-
- `theta`: Array of regression coefficients.
124-
- `overall_error`: Mean absolute error of the Shapley values.
125-
- `error_history`: Array of mean absolute errors for each iteration.
126-
`None` if `return_history=False` in `ls_spa` call.
127-
- `attribution_errors`: Array of absolute errors for each feature.
128-
- `r_squared`: Out-of-sample R-squared statistic of the regression.
128+
- `theta`: Array of regression coefficients with all features.
129+
- `r_squared`: Out-of-sample R² with all features.
130+
- `overall_error`: Estimated error (95th percentile L2 norm) in Shapley attribution vector.
131+
- `attribution_errors`: Array of estimated errors for each feature's attribution.
132+
- `error_history`: Array of error estimates after each batch. `None` if using exact computation.
133+
- `attribution_history`: Array of attribution estimates over time. `None` if `return_attribution_history=False`.
129134

130135
## Citing
131136

ls_spa/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from .ls_spa import (
44
ShapleyResults,
5-
SizeIncompatible,
65
SizeIncompatibleError,
76
error_estimates,
87
ls_spa,
@@ -14,7 +13,6 @@
1413

1514
__all__ = [
1615
"ShapleyResults",
17-
"SizeIncompatible",
1816
"SizeIncompatibleError",
1917
"error_estimates",
2018
"ls_spa",

0 commit comments

Comments
 (0)