Skip to content

Commit a153240

Browse files
committed
fix classification feature importance, deepdive not yet working
1 parent 9b3ed84 commit a153240

File tree

8 files changed

+204
-211
lines changed

8 files changed

+204
-211
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "xaiflow"
7-
version = "0.1.0rc5"
7+
version = "0.1.0rc6"
88
description = "MLflow integration library for generating interactive HTML reports with SHAP analysis using Svelte and Chart.js"
99
authors = [
1010
{name = "CloudExplain Team", email = "tobias@cloudexplain.eu"}
@@ -41,7 +41,8 @@ dev = [
4141
"pytest-cov>=2.12.0",
4242
"black>=21.0.0",
4343
"flake8>=3.9.0",
44-
"playwright"
44+
"playwright",
45+
"pytest-mock",
4546
]
4647

4748
[project.urls]

src/xaiflow/mlflow_plugin.py

Lines changed: 25 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import os
77
import json
8+
import warnings
89
import tempfile
910
import shutil
1011
import mlflow
@@ -33,7 +34,8 @@ def log_feature_importance_report(
3334
importance_values: List[float] | np.ndarray = None,
3435
run_id: Optional[str] = None,
3536
artifact_path: str = "reports",
36-
report_name: str = "feature_importance_report.html"
37+
report_name: str = "feature_importance_report.html",
38+
round_decimals: int = 4
3739
) -> str:
3840
"""
3941
Log an interactive feature importance report as an MLflow artifact
@@ -45,16 +47,33 @@ def log_feature_importance_report(
4547
run_id: MLflow run ID (uses active run if None)
4648
artifact_path: Path within MLflow artifacts to store the report
4749
report_name: Name of the HTML report file
50+
round_decimals: Number of decimals to round feature values and SHAP values
4851
4952
Returns:
5053
str: Path to the logged artifact
5154
"""
5255

5356
if not isinstance(shap_values, Explanation):
5457
raise ValueError("shap_values must be an instance of shap.Explanation. Pls call explainer(X) or similar to get a valid Explanation object.")
55-
feature_values = shap_values.data
56-
base_values = shap_values.base_values
57-
shap_values = shap_values.values
58+
if np.issubdtype(shap_values.data.dtype, np.floating):
59+
feature_values = np.round(shap_values.data, round_decimals)
60+
else:
61+
feature_values = shap_values.data
62+
base_values = np.round(shap_values.base_values, round_decimals)[0]
63+
shap_values = np.round(shap_values.values, round_decimals)
64+
if feature_values.ndim != shap_values.ndim:
65+
if shap_values.ndim - feature_values.ndim > 1:
66+
NotImplementedError("It looks like you're using multi-target regression or multi-output classification. Currently we don't support this."
67+
" You can still use the plugin, just hand over shap_values[..., <target_index>] to get the SHAP values for a specific target/class."
68+
" Please ensure that the shap_values.dim - feature_values.dim is 1 or less.")
69+
else:
70+
warnings.warn("Feature values and SHAP values dimensions do not match. This can be due to multi-target regression or (multi-target) classification."
71+
" If you want a specific target/class, please hand over shap_values[..., <target_index>] to get the SHAP values for that target/class."
72+
" We fall back to shap_values[..., -1] in this case."
73+
)
74+
shap_values = shap_values[..., -1]
75+
base_values = float(base_values[-1])
76+
5877
# Use active run if no run_id provided
5978
if run_id is None:
6079
active_run = mlflow.active_run()
@@ -95,8 +114,8 @@ def log_feature_importance_report(
95114
with open(temp_path, 'w', encoding='utf-8') as f:
96115
f.write(html_content)
97116

98-
with open('test_report.html', 'w', encoding='utf-8') as f:
99-
f.write(html_content)
117+
# with open('test_report.html', 'w', encoding='utf-8') as f:
118+
# f.write(html_content)
100119

101120
# Log the report as an MLflow artifact
102121
artifact_full_path = f"{artifact_path}/{report_name}"
@@ -119,104 +138,6 @@ def log_feature_importance_report(
119138
if os.path.exists(temp_path):
120139
os.unlink(temp_path)
121140

122-
def log_model_explanation_report(
123-
self,
124-
model,
125-
X_test: np.ndarray,
126-
y_test: np.ndarray,
127-
feature_names: List[str],
128-
model_name: str = "model",
129-
run_id: Optional[str] = None,
130-
artifact_path: str = "reports",
131-
report_name: str = "model_explanation_report.html"
132-
) -> str:
133-
"""
134-
Log a comprehensive model explanation report with SHAP analysis
135-
136-
Args:
137-
model: Trained model object
138-
X_test: Test data features
139-
y_test: Test data targets
140-
feature_names: List of feature names
141-
model_name: Name of the model
142-
run_id: MLflow run ID (uses active run if None)
143-
artifact_path: Path within MLflow artifacts to store the report
144-
report_name: Name of the HTML report file
145-
146-
Returns:
147-
str: Path to the logged artifact
148-
"""
149-
150-
try:
151-
import shap
152-
except ImportError:
153-
raise ImportError("SHAP is required for model explanation reports. Install with: pip install shap")
154-
155-
# Use active run if no run_id provided
156-
if run_id is None:
157-
active_run = mlflow.active_run()
158-
if active_run is None:
159-
raise ValueError("No active MLflow run found. Please start a run or provide run_id.")
160-
run_id = active_run.info.run_id
161-
162-
# Calculate feature importance (if model supports it)
163-
try:
164-
if hasattr(model, 'feature_importances_'):
165-
importance_values = model.feature_importances_.tolist()
166-
elif hasattr(model, 'coef_'):
167-
importance_values = np.abs(model.coef_).flatten().tolist()
168-
else:
169-
# Use permutation importance as fallback
170-
from sklearn.inspection import permutation_importance
171-
perm_importance = permutation_importance(model, X_test, y_test, random_state=42)
172-
importance_values = perm_importance.importances_mean.tolist()
173-
except Exception as e:
174-
print(f"Warning: Could not calculate feature importance: {e}")
175-
importance_values = [1.0 / len(feature_names)] * len(feature_names)
176-
177-
# Calculate SHAP values
178-
try:
179-
# Use TreeExplainer for tree-based models, LinearExplainer for linear models
180-
if hasattr(model, 'tree_'):
181-
explainer = shap.TreeExplainer(model)
182-
else:
183-
explainer = shap.LinearExplainer(model, X_test)
184-
185-
# Calculate SHAP values for a subset of test data (for performance)
186-
sample_size = min(100, len(X_test))
187-
sample_indices = np.random.choice(len(X_test), sample_size, replace=False)
188-
X_sample = X_test[sample_indices]
189-
190-
shap_values_matrix = explainer.shap_values(X_sample)
191-
192-
# Handle multi-class case (take first class for now)
193-
if isinstance(shap_values_matrix, list):
194-
shap_values_matrix = shap_values_matrix[0]
195-
196-
shap_values = shap_values_matrix.tolist()
197-
198-
except Exception as e:
199-
print(f"Warning: Could not calculate SHAP values: {e}")
200-
# Generate dummy SHAP values
201-
sample_size = min(100, len(X_test))
202-
shap_values = []
203-
for _ in range(sample_size):
204-
sample_shap = [
205-
np.random.normal(0, abs(imp_val) * 0.1)
206-
for imp_val in importance_values
207-
]
208-
shap_values.append(sample_shap)
209-
210-
# Log the report
211-
return self.log_feature_importance_report(
212-
feature_names=feature_names,
213-
importance_values=importance_values,
214-
shap_values=shap_values,
215-
run_id=run_id,
216-
artifact_path=artifact_path,
217-
report_name=report_name
218-
)
219-
220141
def _generate_html_content(
221142
self,
222143
importance_data: Dict[str, Any],
@@ -289,72 +210,3 @@ def _generate_html_content(
289210
)
290211

291212
return html_content
292-
# Write to file
293-
294-
def get_report_url(self, run_id: str, artifact_path: str = "reports", report_name: str = "feature_importance_report.html") -> str:
295-
"""
296-
Get the MLflow UI URL for viewing the report
297-
298-
Args:
299-
run_id: MLflow run ID
300-
artifact_path: Path within MLflow artifacts where the report is stored
301-
report_name: Name of the HTML report file
302-
303-
Returns:
304-
str: URL to view the report in MLflow UI
305-
"""
306-
307-
# Get the MLflow tracking URI
308-
tracking_uri = mlflow.get_tracking_uri()
309-
310-
# Construct the artifact URL
311-
artifact_full_path = f"{artifact_path}/{report_name}"
312-
313-
if tracking_uri.startswith("http"):
314-
# Remote MLflow server
315-
base_url = tracking_uri.rstrip('/')
316-
url = f"{base_url}/#/experiments/runs/{run_id}/artifacts/{artifact_full_path}"
317-
else:
318-
# Local MLflow server (assume default port 5000)
319-
url = f"http://localhost:5000/#/experiments/runs/{run_id}/artifacts/{artifact_full_path}"
320-
321-
return url
322-
323-
324-
# Convenience functions for easy usage
325-
def log_feature_importance(
326-
feature_names: List[str],
327-
importance_values: List[float],
328-
shap_values: Optional[List[List[float]]] = None,
329-
**kwargs
330-
) -> str:
331-
"""
332-
Convenience function to log feature importance report to MLflow
333-
"""
334-
plugin = XaiflowPlugin()
335-
return plugin.log_feature_importance_report(
336-
feature_names=feature_names,
337-
importance_values=importance_values,
338-
shap_values=shap_values,
339-
**kwargs
340-
)
341-
342-
343-
def log_model_explanation(
344-
model,
345-
X_test: np.ndarray,
346-
y_test: np.ndarray,
347-
feature_names: List[str],
348-
**kwargs
349-
) -> str:
350-
"""
351-
Convenience function to log model explanation report to MLflow
352-
"""
353-
plugin = XaiflowPlugin()
354-
return plugin.log_model_explanation_report(
355-
model=model,
356-
X_test=X_test,
357-
y_test=y_test,
358-
feature_names=feature_names,
359-
**kwargs
360-
)

src/xaiflow/templates/assets/bundle.js

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/xaiflow/templates/assets/bundle.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/xaiflow/templates/components/DeepDiveChart.svelte

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
1111
// Register the necessary components
1212
Chart.register(BarController, BarElement, CategoryScale, LinearScale, Title, Tooltip, Legend);
13+
console.log("DeepDiveChart: NEWNEWNEW Initialized Chart.js components");
1314
1415
interface Props {
1516
shapValues: number[][];
@@ -85,11 +86,11 @@
8586
}
8687
8788
function updateChart(singleShapValues: number[]) {
88-
console.log("updateChart called with singleShapValues:", singleShapValues, "and singleFeatureValues:", singleFeatureValues);
89+
console.log("DeepDiveChart: updateChart called with singleShapValues:", singleShapValues, "and singleFeatureValues:", singleFeatureValues);
8990
maxOfData = Math.max(...singleShapValues);
9091
minOfData = Math.min(...singleShapValues);
9192
const screen = getScreenSizeFlags();
92-
console.log("Screen size flags:", screen);
93+
console.log("DeepDiveChart: Screen size flags:", screen);
9394
9495
// Color mapping based on isHigherOutputBetter prop
9596
pointBackgroundColor = singleShapValues.map(d => {
@@ -99,20 +100,29 @@
99100
// If higher output is NOT better, use normal color mapping (red=high, green=low)
100101
const colorValue = isHigherOutputBetter ? (100 - normalizedValue) : normalizedValue;
101102
103+
console.log("Normalized Value:", normalizedValue, "Color Value:", colorValue);
102104
return colorMap(colorValue);
103105
// return colorValue;
104106
});
105107
106108
console.log("Max of Data", maxOfData);
107109
console.log("Min of Data", minOfData);
108-
110+
109111
if (chart) {
110112
chart.data.labels = featureNames;
113+
console.log("DeepDiveChart: NEW Updating chart with new data", singleShapValues, featureNames, base_value);
111114
cumulativeValues = createCumulativeStartEndRangesFromValues(singleShapValues, base_value)
115+
console.log("DeepDiveChart: NEW 2 Updating chart with new data", cumulativeValues);
112116
chart.data.datasets[0].data = cumulativeValues;
113117
maxCumulativeValue = Math.max(...cumulativeValues.map(d => d[1]));
114-
console.log("Max Cumulative Value", maxCumulativeValue);
115118
chart.data.datasets[0].backgroundColor = pointBackgroundColor;
119+
// Dynamically update y-axis min and max
120+
console.log("DeepDiveChart: Updating chart with new data", cumulativeValues, pointBackgroundColor);
121+
if (chart.options.scales?.y) {
122+
console.log("DeepDiveChart: Updating y-axis min and max to ", Math.floor(minOfData), Math.ceil(maxOfData * 1.05));
123+
chart.options.scales.y.min = Math.floor(minOfData);
124+
chart.options.scales.y.max = Math.ceil(maxOfData * 1.05);
125+
}
116126
// Update x-axis rotation based on screen size
117127
if (chart.options.scales?.x?.ticks) {
118128
chart.options.scales.x.ticks.maxRotation = screen.isLargeScreen ? 45 : 90;
@@ -247,8 +257,8 @@
247257
}
248258
},
249259
y: {
250-
min: 0,
251-
max: Math.floor(maxCumulativeValue * 1.3),
260+
min: Math.floor(minOfData * 0.95),
261+
max: Math.ceil(maxOfData * 1.05),
252262
ticks: {
253263
font: {
254264
size: window.innerWidth < 768 ? 8 : window.innerWidth < 1024 ? 10 : 12
@@ -401,7 +411,7 @@
401411
});
402412
</script>
403413

404-
<canvas bind:this={chartCanvas}></canvas>
414+
<canvas id="deepdive-canvas" bind:this={chartCanvas}></canvas>
405415

406416
<style>
407417
canvas {

src/xaiflow/templates/components/DeepDiveManager.svelte

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
</script>
105105

106106
<div>
107-
<div class="observation-dropdown" style="position:relative;max-width:300px;">
107+
<div class="deepdive-observation-dropdown" style="position:relative;max-width:300px;">
108108
<label for="observation-filter">Filter Observations:</label>
109109
<input id="observation-filter" type="text" bind:value={filterText} placeholder="Type to filter..."
110110
on:focus={handleInputFocus} on:blur={handleInputBlur} autocomplete="off" />

src/xaiflow/templates/utils/utils.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
export function createCumulativeStartEndRangesFromValues(array: number[], base_value: number): [number, number][] {
22
let cumulative = 0;
33
return array.map((value, index) => {
4+
console.log("createCumulativeStartEndRangesFromValues: index", index, "value", value, "cumulative", cumulative, "base_value", base_value);
45
if (index === 0) {
56
cumulative = base_value;
67
}

0 commit comments

Comments
 (0)