55
66import os
77import json
8+ import warnings
89import tempfile
910import shutil
1011import mlflow
@@ -33,7 +34,8 @@ def log_feature_importance_report(
3334 importance_values : List [float ] | np .ndarray = None ,
3435 run_id : Optional [str ] = None ,
3536 artifact_path : str = "reports" ,
36- report_name : str = "feature_importance_report.html"
37+ report_name : str = "feature_importance_report.html" ,
38+ round_decimals : int = 4
3739 ) -> str :
3840 """
3941 Log an interactive feature importance report as an MLflow artifact
@@ -45,16 +47,33 @@ def log_feature_importance_report(
4547 run_id: MLflow run ID (uses active run if None)
4648 artifact_path: Path within MLflow artifacts to store the report
4749 report_name: Name of the HTML report file
50+ round_decimals: Number of decimals to round feature values and SHAP values
4851
4952 Returns:
5053 str: Path to the logged artifact
5154 """
5255
# --- Validate and normalize the SHAP explanation before rendering ---------
if not isinstance(shap_values, Explanation):
    raise ValueError("shap_values must be an instance of shap.Explanation. Pls call explainer(X) or similar to get a valid Explanation object.")
# Round float feature values for a compact report; leave non-float
# (int/categorical) data untouched so no information is lost.
if np.issubdtype(shap_values.data.dtype, np.floating):
    feature_values = np.round(shap_values.data, round_decimals)
else:
    feature_values = shap_values.data
# base_values: take the first sample's base value(s); for single-output
# models this is a scalar, for multi-output an array indexed per class.
base_values = np.round(shap_values.base_values, round_decimals)[0]
shap_values = np.round(shap_values.values, round_decimals)
if feature_values.ndim != shap_values.ndim:
    if shap_values.ndim - feature_values.ndim > 1:
        # BUG FIX: this exception was previously constructed but never
        # raised, so unsupported multi-output inputs fell through silently.
        raise NotImplementedError("It looks like you're using multi-target regression or multi-output classification. Currently we don't support this."
                                  " You can still use the plugin, just hand over shap_values[..., <target_index>] to get the SHAP values for a specific target/class."
                                  " Please ensure that the shap_values.dim - feature_values.dim is 1 or less.")
    else:
        warnings.warn("Feature values and SHAP values dimensions do not match. This can be due to multi-target regression or (multi-target) classification."
                      " If you want a specific target/class, please hand over shap_values[..., <target_index>] to get the SHAP values for that target/class."
                      " We fall back to shap_values[..., -1] in this case."
                      )
        # Fall back to the last output/class along the trailing axis.
        shap_values = shap_values[..., -1]
        base_values = float(base_values[-1])

5877 # Use active run if no run_id provided
5978 if run_id is None :
6079 active_run = mlflow .active_run ()
@@ -95,8 +114,8 @@ def log_feature_importance_report(
# Write the rendered report to the temporary file before uploading.
with open(temp_path, 'w', encoding='utf-8') as f:
    f.write(html_content)

# (removed commented-out local debug write of 'test_report.html' — dead code)

# Log the report as an MLflow artifact
artifact_full_path = f"{artifact_path}/{report_name}"
@@ -119,104 +138,6 @@ def log_feature_importance_report(
119138 if os .path .exists (temp_path ):
120139 os .unlink (temp_path )
121140
def log_model_explanation_report(
    self,
    model,
    X_test: np.ndarray,
    y_test: np.ndarray,
    feature_names: List[str],
    model_name: str = "model",
    run_id: Optional[str] = None,
    artifact_path: str = "reports",
    report_name: str = "model_explanation_report.html"
) -> str:
    """
    Build and log a full model-explanation report (importances + SHAP).

    Args:
        model: Trained model object
        X_test: Test data features
        y_test: Test data targets
        feature_names: List of feature names
        model_name: Name of the model
        run_id: MLflow run ID (uses active run if None)
        artifact_path: Path within MLflow artifacts to store the report
        report_name: Name of the HTML report file

    Returns:
        str: Path to the logged artifact
    """

    try:
        import shap
    except ImportError:
        raise ImportError("SHAP is required for model explanation reports. Install with: pip install shap")

    # Resolve the target run: fall back to the currently active run.
    if run_id is None:
        active_run = mlflow.active_run()
        if active_run is None:
            raise ValueError("No active MLflow run found. Please start a run or provide run_id.")
        run_id = active_run.info.run_id

    # Feature importance: native attributes first, permutation as fallback,
    # uniform weights as a last resort (best-effort by design).
    try:
        if hasattr(model, 'feature_importances_'):
            importance_values = model.feature_importances_.tolist()
        elif hasattr(model, 'coef_'):
            importance_values = np.abs(model.coef_).flatten().tolist()
        else:
            from sklearn.inspection import permutation_importance
            perm = permutation_importance(model, X_test, y_test, random_state=42)
            importance_values = perm.importances_mean.tolist()
    except Exception as e:
        print(f"Warning: Could not calculate feature importance: {e}")
        n_features = len(feature_names)
        importance_values = [1.0 / n_features] * n_features

    # SHAP values on a capped sample; on failure, synthesize small random
    # values scaled by importance so the report can still render.
    try:
        # NOTE(review): only models exposing `tree_` get TreeExplainer;
        # everything else falls to LinearExplainer — confirm this covers
        # the intended model families.
        if hasattr(model, 'tree_'):
            explainer = shap.TreeExplainer(model)
        else:
            explainer = shap.LinearExplainer(model, X_test)

        n_rows = len(X_test)
        sample_size = min(100, n_rows)
        picked = np.random.choice(n_rows, sample_size, replace=False)
        shap_matrix = explainer.shap_values(X_test[picked])

        # Multi-class explainers return a list; keep the first class only.
        if isinstance(shap_matrix, list):
            shap_matrix = shap_matrix[0]

        shap_values = shap_matrix.tolist()
    except Exception as e:
        print(f"Warning: Could not calculate SHAP values: {e}")
        sample_size = min(100, len(X_test))
        shap_values = [
            [np.random.normal(0, abs(imp_val) * 0.1) for imp_val in importance_values]
            for _ in range(sample_size)
        ]

    # Delegate rendering + artifact upload to the feature-importance report.
    return self.log_feature_importance_report(
        feature_names=feature_names,
        importance_values=importance_values,
        shap_values=shap_values,
        run_id=run_id,
        artifact_path=artifact_path,
        report_name=report_name
    )
219-
220141 def _generate_html_content (
221142 self ,
222143 importance_data : Dict [str , Any ],
@@ -289,72 +210,3 @@ def _generate_html_content(
289210 )
290211
291212 return html_content
292- # Write to file
293-
def get_report_url(self, run_id: str, artifact_path: str = "reports", report_name: str = "feature_importance_report.html") -> str:
    """
    Build the MLflow UI URL under which the logged report can be viewed.

    Args:
        run_id: MLflow run ID
        artifact_path: Path within MLflow artifacts where the report is stored
        report_name: Name of the HTML report file

    Returns:
        str: URL to view the report in MLflow UI
    """

    tracking_uri = mlflow.get_tracking_uri()
    artifact_full_path = f"{artifact_path}/{report_name}"

    # NOTE(review): MLflow UI run URLs normally include an experiment id
    # segment; this path omits it — verify against the target MLflow version.
    if tracking_uri.startswith("http"):
        # Remote tracking server: reuse its address.
        base_url = tracking_uri.rstrip('/')
    else:
        # Local tracking: assume the default UI on port 5000.
        base_url = "http://localhost:5000"

    return f"{base_url}/#/experiments/runs/{run_id}/artifacts/{artifact_full_path}"
322-
323-
324- # Convenience functions for easy usage
def log_feature_importance(
    feature_names: List[str],
    importance_values: List[float],
    shap_values: Optional[List[List[float]]] = None,
    **kwargs
) -> str:
    """
    Convenience wrapper: create a plugin instance and log a feature
    importance report to MLflow in one call.
    """
    plugin = XaiflowPlugin()
    return plugin.log_feature_importance_report(
        feature_names=feature_names,
        importance_values=importance_values,
        shap_values=shap_values,
        **kwargs
    )
341-
342-
def log_model_explanation(
    model,
    X_test: np.ndarray,
    y_test: np.ndarray,
    feature_names: List[str],
    **kwargs
) -> str:
    """
    Convenience wrapper: create a plugin instance and log a full model
    explanation report to MLflow in one call.
    """
    plugin = XaiflowPlugin()
    return plugin.log_model_explanation_report(
        model=model,
        X_test=X_test,
        y_test=y_test,
        feature_names=feature_names,
        **kwargs
    )
0 commit comments