Skip to content

Commit 9b3ed84

Browse files
committed
add tests that actually fail if wrong things happen
1 parent 0111c8c commit 9b3ed84

File tree

3 files changed

+73
-176
lines changed

3 files changed

+73
-176
lines changed

src/xaiflow/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
"""
77

88
from .mlflow_plugin import XaiflowPlugin
9-
from .report_generator import ReportGenerator
109

1110
__version__ = "0.1.0"
1211
__author__ = "CloudExplain Team"
1312
__email__ = "tobias@cloudexplain.eu"
1413

15-
__all__ = ["CEMLflowPlugin", "ReportGenerator"]
14+
# Public API of the package. Every entry must be a name bound in this module:
# the plugin class is imported above as ``XaiflowPlugin``, so exporting the
# stale ``CEMLflowPlugin`` name made ``from xaiflow import *`` raise
# AttributeError.
__all__ = ["XaiflowPlugin"]

src/xaiflow/report_generator.py

Lines changed: 0 additions & 117 deletions
This file was deleted.

tests/test_mlflow_plugin.py

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import pytest
22
from pathlib import Path
3-
from xaiflow.report_generator import ReportGenerator
43
from xaiflow.mlflow_plugin import XaiflowPlugin
54
import numpy as np
65
from sklearn.datasets import fetch_openml
7-
from sklearn.ensemble import RandomForestRegressor
6+
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
7+
from xgboost import XGBClassifier
8+
from catboost import CatBoostClassifier
89
from sklearn.preprocessing import LabelEncoder
910
import shap
11+
from typing import Callable
1012

1113
from playwright.sync_api import sync_playwright
1214

15+
1316
def store_report(html_content, filename="test_report.html"):
1417
"""Helper function to store HTML content to a file."""
1518
outputs_dir = Path("tests/outputs")
@@ -19,6 +22,30 @@ def store_report(html_content, filename="test_report.html"):
1922
f.write(html_content)
2023
return html_path
2124

25+
def save_and_click_canvas_wrapper(func: Callable) -> Callable:
    """Decorate a test that returns report HTML with a browser smoke check.

    The wrapped callable is expected to return the HTML content of a report.
    The decorator stores that HTML via ``store_report`` under
    ``<func.__name__>.html``, opens it in headless Chromium (recording a video
    into ``tests/outputs/videos/``), clicks the centre of the importance-chart
    canvas and asserts that the canvas is present, visible and not blank.

    Args:
        func: Callable returning the HTML content to verify.

    Returns:
        A callable with the same name as ``func`` that runs the checks and
        returns ``None``.
    """
    # Local import so this block does not depend on a file-level import.
    from functools import wraps

    # Fraction of identically coloured pixels at which the canvas counts as
    # blank. Must stay in sync with the literal used in the script below.
    uniform_threshold = 0.9

    # Compares every pixel with the first one and reports the canvas as empty
    # when at least 90% of the pixels share that colour.
    blank_canvas_script = """(node) => {
        const ctx = node.getContext('2d');
        const data = ctx.getImageData(0, 0, node.width, node.height).data;
        const totalPixels = data.length / 4;
        if (totalPixels === 0) return true;
        // Get first pixel color
        const r0 = data[0], g0 = data[1], b0 = data[2], a0 = data[3];
        let sameCount = 0;
        for (let i = 0; i < data.length; i += 4) {
            const r = data[i], g = data[i+1], b = data[i+2], a = data[i+3];
            if (r === r0 && g === g0 && b === b0 && a === a0) {
                sameCount++;
            }
        }
        return (sameCount / totalPixels) >= 0.9;
    }"""

    @wraps(func)  # keep the test's own name/docstring for reports and debugging
    def wrapper(*args, **kwargs):
        html_content = func(*args, **kwargs)
        html_path = store_report(html_content, func.__name__ + ".html")
        with sync_playwright() as p:
            browser = p.chromium.launch()
            try:
                context = browser.new_context(record_video_dir="tests/outputs/videos/")
                try:
                    page = context.new_page()
                    # as_uri() yields a well-formed, percent-encoded file URL.
                    page.goto(html_path.resolve().as_uri())
                    page.wait_for_selector(".importance-chart-container canvas")
                    canvas = page.query_selector(".importance-chart-container canvas")
                    assert canvas is not None, "Canvas not found in importance chart container"
                    assert canvas.is_visible(), "Canvas is not visible"
                    box = canvas.bounding_box()
                    assert box is not None, "Canvas has no bounding box"
                    x = box["x"] + box["width"] / 2
                    y = box["y"] + box["height"] / 2
                    page.mouse.click(x, y)
                    # Give the page time to react to the click before sampling pixels.
                    page.wait_for_timeout(1000)
                    is_empty = canvas.evaluate(blank_canvas_script)
                    assert not is_empty, (
                        "Canvas is empty or nearly empty: at least "
                        f"{uniform_threshold:.0%} of pixels have the same color"
                    )
                finally:
                    # Closing the context is what finalizes the recorded video;
                    # do it even when an assertion above fails.
                    context.close()
            finally:
                browser.close()

    return wrapper
47+
48+
@save_and_click_canvas_wrapper
2249
def test_categorical_feature_encodings():
2350
data = fetch_openml(data_id=196, as_frame=True)
2451
X = data.data.copy()[:100]
@@ -59,25 +86,10 @@ def test_categorical_feature_encodings():
5986
feature_encodings=feature_encodings,
6087
feature_names=list(X.columns),
6188
)
62-
html_path = store_report(html_content, "test_categorical_feature_encodings.html")
63-
with sync_playwright() as p:
64-
browser = p.chromium.launch()
65-
context = browser.new_context(record_video_dir="tests/outputs/videos/")
66-
page = context.new_page()
67-
page.goto(f"file://{html_path.resolve()}")
68-
page.wait_for_selector(".importance-chart-container canvas")
69-
canvas = page.query_selector(".importance-chart-container canvas")
70-
assert canvas is not None, "Canvas not found in importance chart container"
71-
assert canvas.is_visible(), "Canvas is not visible"
72-
box = canvas.bounding_box()
73-
x = box["x"] + box["width"] / 2
74-
y = box["y"] + box["height"] / 2
75-
page.mouse.click(x, y)
76-
page.wait_for_timeout(1000) # Wait to ensure video is recorded
77-
context.close() # This will save the video file
78-
browser.close()
89+
return html_content
7990

8091

92+
@save_and_click_canvas_wrapper
8193
def test_no_feature_encodings():
8294
plugin = XaiflowPlugin()
8395
html_content = plugin._generate_html_content(
@@ -87,24 +99,9 @@ def test_no_feature_encodings():
8799
feature_encodings=None,
88100
feature_names=['Feature 1', 'Feature 2', 'Feature 3'],
89101
)
90-
html_path = store_report(html_content, "test_no_feature_encodings.html")
91-
with sync_playwright() as p:
92-
browser = p.chromium.launch()
93-
context = browser.new_context(record_video_dir="tests/outputs/videos/")
94-
page = context.new_page()
95-
page.goto(f"file://{html_path.resolve()}")
96-
page.wait_for_selector(".importance-chart-container canvas")
97-
canvas = page.query_selector(".importance-chart-container canvas")
98-
assert canvas is not None, "Canvas not found in importance chart container"
99-
assert canvas.is_visible(), "Canvas is not visible"
100-
box = canvas.bounding_box()
101-
x = box["x"] + box["width"] / 2
102-
y = box["y"] + box["height"] / 2
103-
page.mouse.click(x, y)
104-
page.wait_for_timeout(1000) # Wait to ensure video is recorded
105-
context.close() # This will save the video file
106-
browser.close()
102+
return html_content
107103

104+
@save_and_click_canvas_wrapper
108105
def test_fix_previous_bug():
109106
importanceData = {'features':
110107
['acv_score_canc_30d',
@@ -155,25 +152,43 @@ def test_fix_previous_bug():
155152
feature_encodings=featureEncodings,
156153
feature_names=featureNames,
157154
)
158-
html_path = store_report(html_content, "test_fix_previous_bug.html")
159-
with sync_playwright() as p:
160-
browser = p.chromium.launch()
161-
context = browser.new_context(record_video_dir="tests/outputs/videos/")
162-
page = context.new_page()
163-
page.goto(f"file://{html_path.resolve()}")
164-
page.wait_for_selector(".importance-chart-container canvas")
165-
canvas = page.query_selector(".importance-chart-container canvas")
166-
assert canvas is not None, "Canvas not found in importance chart container"
167-
assert canvas.is_visible(), "Canvas is not visible"
168-
box = canvas.bounding_box()
169-
x = box["x"] + box["width"] / 2
170-
y = box["y"] + box["height"] / 2
171-
page.mouse.click(x, y)
172-
page.wait_for_timeout(1000) # Wait to ensure video is recorded
173-
context.close() # This will save the video file
174-
browser.close()
175-
176-
# def test_importance_chart_canvas_click():
177-
# html_path = Path("tests/outputs/test_no_feature_encodings.html")
178-
# assert html_path.exists(), "Report HTML file does not exist. Run test_no_feature_encodings first."
155+
return html_content
156+
157+
158+
@save_and_click_canvas_wrapper
def test_classification_case():
    """Build report HTML for a random-forest classifier on the SHAP adult data.

    Numeric columns are mean-imputed, categorical columns (if any) are
    label-encoded into ``<col>_encoded`` columns, a classifier is fitted and
    explained with ``shap.TreeExplainer``, and the resulting HTML is returned
    so the decorator can run its browser checks on it.
    """
    X, y = shap.datasets.adult(n_points=200)
    X = X.copy()  # avoid mutating (a possible view of) the loaded dataset

    # Identify categorical columns
    categorical_cols = [col for col in X.columns if X[col].dtype == 'category' or X[col].dtype == 'object']
    numeric_cols = [col for col in X.columns if col not in categorical_cols]

    # Fill missing values manually
    for col in numeric_cols:
        X[col] = X[col].astype(float).fillna(X[col].mean())

    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        X[col + '_encoded'] = le.fit_transform(X[col].astype(str))  # convert to string in case of NaNs
        label_encoders[col] = le  # kept to build the code -> label mapping below

    # The model must only see the encoded columns; leaving the raw categorical
    # columns in place would make fitting fail on non-numeric values.
    X = X.drop(columns=categorical_cols)

    # Train model (seeded so the rendered report is reproducible between runs)
    rfc = RandomForestClassifier(random_state=0)
    rfc.fit(X, y)
    explanation = shap.TreeExplainer(rfc)(X)

    shap_matrix = np.asarray(explanation.values)
    base_values = np.asarray(explanation.base_values)
    if shap_matrix.ndim == 3:
        # NOTE(review): for classifiers shap can return one attribution per
        # class (samples x features x classes). Keep the last class so there
        # is exactly one value per feature; confirm this is the class the
        # report should explain.
        shap_matrix = shap_matrix[:, :, -1]
        base_values = base_values[:, -1]

    # Map each encoded integer back to its original category label.
    feature_encodings = {
        col + '_encoded': dict(enumerate(label_encoders[col].classes_))
        for col in categorical_cols
    }

    feature_names = list(X.columns)
    plugin = XaiflowPlugin()
    html_content = plugin._generate_html_content(
        importance_data={'features': feature_names, 'values': np.abs(shap_matrix).mean(axis=0).tolist()},
        shap_values=shap_matrix,
        feature_values=explanation.data,
        base_values=base_values[0],
        feature_encodings=feature_encodings,
        feature_names=feature_names,
    )
    return html_content

0 commit comments

Comments
 (0)