Skip to content

Commit 8251ec5

Browse files
authored
Final cleanups and add video to description (#4)
* add video * change readme * add video * add video * add video * add video * add video * add video * add video * add video * update example
1 parent 7d99e87 commit 8251ec5

File tree

8 files changed

+29
-33
lines changed

8 files changed

+29
-33
lines changed

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
77
xaiflow integrates seamlessly with MLflow to generate interactive HTML reports for SHAP analysis. Instead of static charts and images, you get rich, interactive visualizations that stakeholders can explore and understand.
88

9+
Feature showcase video:
10+
[![xaiflow showcase](video/video_thumbnail.png)](https://github.com/user-attachments/assets/f508fa6f-ab0f-493d-a892-ed958331e30a)
11+
*Click the image above to watch the feature showcase video.*
12+
913
## What We're Trying to Achieve
1014

1115
Most ML workflows produce explanations as static images or basic charts, which creates several problems:
@@ -38,10 +42,9 @@ with mlflow.start_run():
3842

3943
# Add interactive explainable AI reports
4044
plugin = XaiflowPlugin()
41-
plugin.log_feature_importance_report(
45+
plugin.log_xai_report(
4246
feature_names=X.columns.tolist(),
4347
shap_values=shap_values,
44-
report_name="model_explanation.html"
4548
)
4649
```
4750

@@ -81,7 +84,7 @@ xaiflow/
8184
### Core Components
8285

8386
**MLflow Integration** (`mlflow_plugin.py`)
84-
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_feature_importance_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.
87+
The `XaiflowPlugin` class handles the integration with MLflow. The main method `log_xai_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.
8588

8689
**Report Generation** (`report_generator.py`)
8790
The `ReportGenerator` class converts SHAP data into interactive HTML reports using Jinja2 templating. It handles template loading, asset bundling, and data injection into the frontend components.
@@ -127,11 +130,10 @@ feature_encodings = {
127130
'region': {0: 'North', 1: 'South', 2: 'East', 3: 'West'}
128131
}
129132

130-
plugin.log_feature_importance_report(
133+
plugin.log_xai_report(
131134
feature_names=feature_names,
132135
shap_values=shap_values,
133136
feature_encodings=feature_encodings,
134-
report_name="enhanced_report.html"
135137
)
136138
```
137139

examples/notebooks/auto_mpg_example.ipynb

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 1,
66
"id": "79de156f",
77
"metadata": {},
88
"outputs": [
99
{
1010
"name": "stderr",
1111
"output_type": "stream",
1212
"text": [
13-
"/home/tobias/programming/cloudexplain/ce-mlflow-extension/ce-mlflow-extension/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13+
"c:\\programming\\cloudexplain\\xflow\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
1414
" from .autonotebook import tqdm as notebook_tqdm\n"
1515
]
1616
}
@@ -42,21 +42,19 @@
4242
},
4343
{
4444
"cell_type": "code",
45-
"execution_count": null,
45+
"execution_count": 3,
4646
"id": "0da5d7e2",
4747
"metadata": {},
4848
"outputs": [
4949
{
5050
"name": "stdout",
5151
"output_type": "stream",
5252
"text": [
53-
"Loaded bundle.js content (218107 characters)\n",
54-
"Saved report data to test_report_data.json\n",
55-
"logged to test_report.html\n",
56-
"Feature importance report logged to MLflow: reports/test_report_auto_mpg.html\n",
57-
"Run ID: 72ea715bfe9c42bc840388933f6999a8. If you are running mlflow locally use:\n",
53+
"Loaded bundle.js content (225719 characters)\n",
54+
"Feature importance report logged to MLflow: reports/feature_importance_report.html\n",
55+
"Run ID: 7521c3f260f84a5d8e038a13bc91498b. If you are running mlflow locally use:\n",
5856
"python -m mlflow ui --port 5000\n",
59-
"Then open http://localhost:5000/#/experiments/921177506761828334/runs/72ea715bfe9c42bc840388933f6999a8 to view the report.\n"
57+
"Then open http://localhost:5000/#/experiments/557047036753041520/runs/7521c3f260f84a5d8e038a13bc91498b to view the report. Note: it's important to start mlflow in the directory in which you execute the notebook.\n"
6058
]
6159
}
6260
],
@@ -91,10 +89,9 @@
9189
" feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},\n",
9290
" 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},\n",
9391
" 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}\n",
94-
" artifact_path = plugin.log_feature_importance_report(\n",
92+
" artifact_path = plugin.log_xai_report(\n",
9593
" feature_names=list(X.columns),\n",
9694
" shap_values=shap_values,\n",
97-
" report_name=\"test_report_auto_mpg.html\",\n",
9895
" feature_encodings=feature_encodings\n",
9996
" )\n",
10097
" run_id = mlflow.active_run().info.run_id\n",
@@ -119,7 +116,7 @@
119116
"name": "python",
120117
"nbconvert_exporter": "python",
121118
"pygments_lexer": "ipython3",
122-
"version": "3.12.3"
119+
"version": "3.13.5"
123120
}
124121
},
125122
"nbformat": 4,

examples/notebooks/azure_ml_auto_mpg_example.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,9 @@
258258
" feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},\n",
259259
" 'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},\n",
260260
" 'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}\n",
261-
" artifact_path = plugin.log_feature_importance_report(\n",
261+
" artifact_path = plugin.log_xai_report(\n",
262262
" feature_names=list(X.columns),\n",
263263
" shap_values=shap_values,\n",
264-
" report_name=\"test_report_auto_mpg.html\",\n",
265264
" feature_encodings=feature_encodings\n",
266265
" )\n",
267266
" run_id = mlflow.active_run().info.run_id\n",

examples/scripts/auto_mpg_example.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@
4343
feature_encodings = {'cylinders_encoded': {0: '3', 1: '4', 2: '5', 3: '6', 4: '8'},
4444
'model_encoded': {0: 'Super 70', 1: 'Super 71', 2: 'Low 72', 3: 'Nice 73', 4: 'Great 74', 5: 'Lame 75', 6: 'High 76', 7: '77', 8: '78', 9: '79', 10: '80', 11: '81', 12: '82'},
4545
'origin_encoded': {0: 'Afghanistan', 1: 'Bangladesh', 2: 'Maui'}}
46-
artifact_path = plugin.log_feature_importance_report(
46+
artifact_path = plugin.log_xai_report(
4747
feature_names=list(X.columns),
4848
shap_values=shap_values,
49-
report_name="test_report_auto_mpg.html",
50-
feature_encodings=feature_encodings
49+
feature_encodings=feature_encodings,
50+
# assign each sample to a custom group label
51+
group_labels=["Custom Group " + str(i % 4) for i in range(len(X))],
5152
)
5253
run_id = mlflow.active_run().info.run_id
5354
print(f"Run ID: {run_id}. If you are running mlflow locally use:\npython -m mlflow ui --port 5000\nThen open http://localhost:5000/#/experiments/{mlflow.get_experiment_by_name(experiment_name).experiment_id}/runs/{run_id} to view the report.",

src/xaiflow/mlflow_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self):
2626
self.template_dir = os.path.join(os.path.dirname(__file__), 'templates')
2727
self.env = Environment(loader=FileSystemLoader(self.template_dir))
2828

29-
def log_feature_importance_report(
29+
def log_xai_report(
3030
self,
3131
feature_names: List[str],
3232
shap_values: Explanation,

src/xflow.egg-info/PKG-INFO

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ with mlflow.start_run():
8080

8181
# Add interactive explainable AI reports
8282
plugin = CEMLflowPlugin()
83-
plugin.log_feature_importance_report(
83+
plugin.log_xai_report(
8484
feature_names=X.columns.tolist(),
8585
shap_values=shap_values,
8686
report_name="model_explanation.html"
@@ -121,7 +121,7 @@ xaiflow/
121121
### Core Components
122122

123123
**MLflow Integration** (`mlflow_plugin.py`)
124-
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_feature_importance_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.
124+
The `CEMLflowPlugin` class handles the integration with MLflow. The main method `log_xai_report()` processes SHAP values, manages feature encodings, and stores the generated reports as MLflow artifacts.
125125

126126
**Report Generation** (`report_generator.py`)
127127
The `ReportGenerator` class converts SHAP data into interactive HTML reports using Jinja2 templating. It handles template loading, asset bundling, and data injection into the frontend components.
@@ -167,7 +167,7 @@ feature_encodings = {
167167
'region': {0: 'North', 1: 'South', 2: 'East', 3: 'West'}
168168
}
169169

170-
plugin.log_feature_importance_report(
170+
plugin.log_xai_report(
171171
feature_names=feature_names,
172172
shap_values=shap_values,
173173
feature_encodings=feature_encodings,

tests/test_mlflow_plugin.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import numpy as np
55
from sklearn.datasets import fetch_openml
66
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
7-
from xgboost import XGBClassifier
8-
from catboost import CatBoostClassifier
97
from sklearn.preprocessing import LabelEncoder
108
import shap
119
from typing import Callable
@@ -187,7 +185,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
187185
mocker.patch("mlflow.log_artifact")
188186

189187
with mlflow.start_run(run_name="auto_mpg_test"):
190-
plugin.log_feature_importance_report(
188+
plugin.log_xai_report(
191189
shap_values=shap_values,
192190
feature_encodings=feature_encodings,
193191
feature_names=list(X.columns),
@@ -263,7 +261,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
263261
mocker.patch("mlflow.log_artifact")
264262

265263
with mlflow.start_run(run_name="auto_mpg_test"):
266-
plugin.log_feature_importance_report(
264+
plugin.log_xai_report(
267265
shap_values=shap_values,
268266
feature_encodings=feature_encodings,
269267
feature_names=list(X.columns),
@@ -319,11 +317,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
319317
mocker.patch("mlflow.log_artifact")
320318

321319
with mlflow.start_run(run_name="auto_mpg_test"):
322-
plugin.log_feature_importance_report(
320+
plugin.log_xai_report(
323321
shap_values=shap_values,
324322
feature_encodings=feature_encodings,
325323
feature_names=list(X.columns),
326324
group_labels=["Group 1", "Group 2", "Group 3", "Group 4"] * int(len(shap_values) / 4) # Example group labels
327325
)
328-
html_content_click_test(Path(output_path))
329-
# return html_content
326+
html_content_click_test(Path(output_path))

video/video_thumbnail.png

189 KB
Loading

0 commit comments

Comments
 (0)