Skip to content

Commit de00cae

Browse files
authored
fix cli tests (#314)
* move to pyproject.toml * delete old setup.py etc * ruff format * ruff check fixes * limit deps * missing imports * squeeze skorch shap values tensors * use ruff as gh actions linter * update gh actions workflows * install --system * fix ci script installation * fix cli tests * add explainer.yaml to test assets * further cli test fixes
1 parent f919915 commit de00cae

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

explainerdashboard/explainers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,8 @@ def to_yaml(
507507
modelfile=modelfile,
508508
datafile=datafile,
509509
explainerfile=explainerfile,
510-
data_target=self.target,
511-
data_index=self.idxs.name,
510+
data_target=target_col or self.target,
511+
data_index=index_col or self.idxs.name,
512512
explainer_type="classifier" if self.is_classifier else "regression",
513513
dashboard_yaml=dashboard_yaml,
514514
params=self._params_dict,

tests/hub/test_hub_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
def test_explainerhub_cli_help(explainer_hub_dump_folder, script_runner):
77
ret = script_runner.run(
8-
["explainerhub", " --help"], cwd=str(explainer_hub_dump_folder)
8+
["explainerhub", "--help"], cwd=str(explainer_hub_dump_folder)
99
)
1010
assert ret.success
1111
assert ret.stderr == ""

tests/test_cli.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pickle
12
import pytest
23
from pathlib import Path
34

@@ -24,6 +25,7 @@ def generate_assets():
2425
cats=[{"Gender": ["Sex_female", "Sex_male", "Sex_nan"]}, "Deck", "Embarked"],
2526
labels=["Not survived", "Survived"],
2627
)
28+
2729

2830
dashboard = ExplainerDashboard(
2931
explainer,
@@ -36,18 +38,30 @@ def generate_assets():
3638
)
3739

3840
pkl_dir = Path.cwd() / "tests" / "test_assets"
39-
explainer.to_yaml(pkl_dir / "explainer.yaml")
41+
pkl_dir.mkdir(parents=True, exist_ok=True)
42+
explainer_joblib_path = pkl_dir / "explainer.joblib"
43+
model_path = pkl_dir / "model.pkl"
44+
45+
explainer_yaml_path = pkl_dir / "explainer.yaml"
46+
dashboard_yaml_path = pkl_dir / "dashboard.yaml"
47+
pickle.dump(model, open(model_path, "wb"))
48+
explainer.to_yaml(explainer_yaml_path, index_col="Name", target_col="Survival")
49+
explainer.dump(explainer_joblib_path)
4050
dashboard.to_yaml(
41-
pkl_dir / "dashboard.yaml",
42-
explainerfile=str(pkl_dir / "explainer.joblib"),
51+
dashboard_yaml_path,
52+
explainerfile=str(explainer_joblib_path),
4353
dump_explainer=True,
4454
)
45-
return None
55+
yield
4656

57+
explainer_joblib_path.unlink()
58+
explainer_yaml_path.unlink()
59+
dashboard_yaml_path.unlink()
60+
model_path.unlink()
4761

4862
def test_explainerdashboard_cli_help(generate_assets, script_runner):
4963
ret = script_runner.run(
50-
["explainerdashboard", " --help"],
64+
["explainerdashboard", "--help"],
5165
cwd=str(Path().cwd() / "tests" / "test_assets"),
5266
)
5367
assert ret.success
@@ -56,7 +70,7 @@ def test_explainerdashboard_cli_help(generate_assets, script_runner):
5670

5771
def test_explainerdashboard_cli_explainer(generate_assets, script_runner):
5872
ret = script_runner.run(
59-
["explainerdashboard", " test explainer.joblib"],
73+
["explainerdashboard", "test", "explainer.joblib"],
6074
cwd=str(Path().cwd() / "tests" / "test_assets"),
6175
)
6276
assert ret.success
@@ -65,7 +79,7 @@ def test_explainerdashboard_cli_explainer(generate_assets, script_runner):
6579

6680
def test_explainerdashboard_cli_yaml(generate_assets, script_runner):
6781
ret = script_runner.run(
68-
["explainerdashboard", " test dashboard.yaml"],
82+
["explainerdashboard", "test", "dashboard.yaml"],
6983
cwd=str(Path().cwd() / "tests" / "test_assets"),
7084
)
7185
assert ret.success
@@ -74,7 +88,7 @@ def test_explainerdashboard_cli_yaml(generate_assets, script_runner):
7488

7589
def test_explainerdashboard_cli_build(generate_assets, script_runner):
7690
ret = script_runner.run(
77-
["explainerdashboard", " build"],
91+
["explainerdashboard", "build", "explainer.yaml"],
7892
cwd=str(Path().cwd() / "tests" / "test_assets"),
7993
)
8094
assert ret.success
@@ -83,7 +97,7 @@ def test_explainerdashboard_cli_build(generate_assets, script_runner):
8397

8498
def test_explainerdashboard_cli_build_explainer(generate_assets, script_runner):
8599
ret = script_runner.run(
86-
["explainerdashboard", " build explainer.yaml"],
100+
["explainerdashboard", "build", "explainer.yaml"],
87101
cwd=str(Path().cwd() / "tests" / "test_assets"),
88102
)
89103
assert ret.success
@@ -94,7 +108,7 @@ def test_explainerdashboard_cli_build_explainer_dashboard(
94108
generate_assets, script_runner
95109
):
96110
ret = script_runner.run(
97-
["explainerdashboard", " build explainer.yaml dashboard.yaml"],
111+
["explainerdashboard", "build", "explainer.yaml", "dashboard.yaml"],
98112
cwd=str(Path().cwd() / "tests" / "test_assets"),
99113
)
100114
assert ret.success

0 commit comments

Comments
 (0)