Skip to content

Commit 7075939

Browse files
committed
misc cleaning
1 parent b2fd0f7 commit 7075939

File tree

5 files changed

+24
-10
lines changed

5 files changed

+24
-10
lines changed

.github/workflows/pre-commit.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ jobs:
1616
- name: Cache uv dependencies
1717
uses: actions/cache@v4
1818
with:
19-
path: .venv
20-
key: ${{ runner.os }}-uv-${{ hashFiles('uv.lock') }}
19+
path: |
20+
.venv
21+
~/.cache/uv
22+
key: ${{ runner.os }}-uv-${{ hashFiles('uv.lock', 'pyproject.toml') }}
2123
restore-keys: |
22-
${{ runner.os }}-uv-
24+
${{ runner.os }}-uv--
2325
- name: Install dependencies
2426
run: uv sync
2527
- uses: pre-commit/[email protected]

.github/workflows/release.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ jobs:
1616
- name: Cache uv dependencies
1717
uses: actions/cache@v4
1818
with:
19-
path: .venv
19+
path: |
20+
.venv
21+
~/.cache/uv
2022
key: ${{ runner.os }}-uv-${{ hashFiles('uv.lock', 'pyproject.toml') }}
2123
restore-keys: |
2224
${{ runner.os }}-uv-

ls_spa/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .ls_spa import (
44
ShapleyResults,
5+
SizeIncompatible,
56
SizeIncompatibleError,
67
error_estimates,
78
ls_spa,
@@ -13,6 +14,7 @@
1314

1415
__all__ = [
1516
"ShapleyResults",
17+
"SizeIncompatible",
1618
"SizeIncompatibleError",
1719
"error_estimates",
1820
"ls_spa",

ls_spa/ls_spa.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from numpy import random
3030

3131
# The maximum number of features for which we can display the full attribution.
32-
CAN_DISPLAY_FULL_ATTR = 5
32+
MAX_ATTR_DISP = 5
3333

3434
# The maximum number of features for which we can feasibly compute the exact Shapley values.
3535
MAX_FEAS_EXACT_FEATS = 9
@@ -52,15 +52,19 @@ def __repr__(self) -> str:
5252
attr_str = ""
5353
coefs_str = ""
5454

55-
if len(self.attribution) <= CAN_DISPLAY_FULL_ATTR:
55+
if len(self.attribution) <= MAX_ATTR_DISP:
5656
attr_str = "(" + "".join(f"{a:.2f}, " for a in self.attribution.flatten())[:-2] + ")"
5757
coefs_str = "(" + "".join(f"{c:.2f}, " for c in self.theta.flatten())[:-2] + ")"
5858
else:
5959
attr_str = (
60-
"(" + "".join(f"{a:.2f}, " for a in self.attribution.flatten()[:5])[:-2] + ", ...)"
60+
"("
61+
+ "".join(f"{a:.2f}, " for a in self.attribution.flatten()[:MAX_ATTR_DISP])[:-2]
62+
+ ", ...)"
6163
)
6264
coefs_str = (
63-
"(" + "".join(f"{c:.2f}, " for c in self.theta.flatten()[:5])[:-2] + ", ...)"
65+
"("
66+
+ "".join(f"{c:.2f}, " for c in self.theta.flatten()[:MAX_ATTR_DISP])[:-2]
67+
+ ", ...)"
6468
)
6569

6670
return f"""
@@ -83,6 +87,11 @@ def __init__(self, message: str) -> None:
8387
super().__init__(self.message)
8488

8589

90+
# TODO(ndevanathan): remove this in the next major update
91+
# This is here for backwards compatibility
92+
SizeIncompatible = SizeIncompatibleError
93+
94+
8695
def validate_data(
8796
X_train: np.ndarray,
8897
X_test: np.ndarray,
@@ -410,7 +419,7 @@ def error_estimates(rng: random.Generator, cov: np.ndarray) -> tuple[np.ndarray,
410419
p = cov.shape[0]
411420
try:
412421
sample_diffs = rng.multivariate_normal(np.zeros(p), cov, size=2**10, method="cholesky")
413-
except: # noqa: E722
422+
except (np.linalg.LinAlgError, ValueError):
414423
sample_diffs = rng.multivariate_normal(np.zeros(p), cov, size=2**10, method="svd")
415424
abs_diffs = np.abs(sample_diffs)
416425
norms = np.linalg.norm(sample_diffs, axis=1)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ extend-exclude = [
7777
"notebooks",
7878
"paper",
7979
"data",
80-
".venv",
8180
"__marimo__",
8281
"__pycache__",
8382
"build",

0 commit comments

Comments
 (0)