@@ -207,59 +207,6 @@ def test_no_feature_encodings():
207207 )
208208 return html_content
209209
@save_and_click_canvas_wrapper
def test_fix_previous_bug():
    """Build report HTML from a fixed, hand-recorded payload and return it.

    The payload has 17 feature names, 17 feature values and 17 two-element
    SHAP rows, while the importance data carries only two ``values``.
    ``feature_encodings`` is deliberately ``None``.

    Returns:
        Whatever ``XaiflowPlugin._generate_html_content`` produces for the
        payload; the decorator receives that value.
        NOTE(review): presumably an HTML string that the wrapper saves and
        click-tests -- confirm against the decorator's definition.
    """
    feature_names = [
        'acv_score_canc_30d',
        'avg_canc_dealer_no_weighted',
        'ctr_usa_sec_inc_voice_a6m',
        'avg_canc_reseller_id_weighted',
        'ctr_usa_kb_data_usg_a3m',
        'ctr_sales_channel_current',
        'ctr_cancellations_per_year',
        'avg_vvl_reseller_id_weighted',
        'ctr_start_days',
        'ctr_min_duration_date_crm_days',
        'rlz',
        'vvl_l_event_days',
        'avg_vvl_sales_channel',
        'ctr_dealer_no_current',
        'avg_canc_sales_channel',
        'prt_cancellation_page_visit_90d_count',
        'acv_score_vvl_30d',
    ]
    importance_data = {
        # Separate list object with the same names, so the two arguments
        # never alias each other.
        'features': list(feature_names),
        'values': [0.5000000000000614, 0.49999999999993844],
    }
    shap_values = [
        [0.05231021109253422, -0.05231021109253736],
        [0.0073606489440402965, -0.007360648944034653],
        [-0.01633880222219225, 0.016338802222170094],
        [0.012322311243639975, -0.012322311243637033],
        [-0.004445322661143976, 0.004445322661143468],
        [0.0009611405151175154, -0.0009611405151178431],
        [0.005596997502669034, -0.0055969975026683915],
        [-0.0008618250588141368, 0.0008618250587932731],
        [0.0016991238750754237, -0.0016991238750824476],
        [0.0048252568432011304, -0.004825256843199152],
        [-0.00038499217151075256, 0.00038499217151299176],
        [0.005172501948575322, -0.005172501948575318],
        [-0.003383349580534422, 0.003383349580535079],
        [-0.017147577240666855, 0.017147577240670973],
        [0.008064862968425773, -0.008064862968423504],
        [0.0018500348673166761, -0.0018500348673163927],
        [0.006529750924148127, -0.006529750924151195],
    ]
    feature_values = [
        0.03421833738684654,
        0.022704629679359795,
        15.0,
        0.022704629679359795,
        30193.0,
        241.0,
        0.0,
        0.022739316468840073,
        951.5416666666666,
        2912717.0,
        30.0,
        9999.0,
        0.019356054262267625,
        9.0,
        0.02191634567074192,
        0.0,
        0.2959745228290558,
    ]
    # Recorded alongside the payload; it is not passed to the plugin call.
    base_values = [0.9058690282100686, 0.09413097178993132]
    feature_encodings = None

    return XaiflowPlugin()._generate_html_content(
        importance_data=importance_data,
        shap_values=shap_values,
        feature_values=feature_values,
        feature_encodings=feature_encodings,
        feature_names=feature_names,
    )
262-
263210
264211def test_classification_case (mocker ):
265212 X , y = shap .datasets .adult (n_points = 200 )
@@ -322,4 +269,61 @@ def __exit__(self, exc_type, exc_val, exc_tb):
322269 feature_names = list (X .columns ),
323270 )
324271 html_content_click_test (Path (output_path ))
272+ # return html_content
273+
274+
def test_classification_case_check_list_feature(mocker):
    """Log a feature-importance report with a list of group labels and click-test it.

    Trains a ``RandomForestClassifier`` on 200 rows of the SHAP adult
    dataset, explains it with ``shap.TreeExplainer`` and passes the result
    to ``XaiflowPlugin.log_feature_importance_report`` together with one
    group label per row. The temporary file the plugin would normally use
    is redirected to a fixed path under ``tests/outputs`` so the HTML
    survives the call and can be handed to ``html_content_click_test``.

    Args:
        mocker: pytest-mock fixture used to patch ``tempfile``, ``os.unlink``
            and ``mlflow.log_artifact``.
    """
    X, y = shap.datasets.adult(n_points=200)

    # Split columns by dtype so each kind gets the right preprocessing.
    categorical_cols = [
        col
        for col in X.columns
        if X[col].dtype == 'category' or X[col].dtype == 'object'
    ]
    numeric_cols = [col for col in X.columns if col not in categorical_cols]

    # Fill missing numeric values with the column mean.
    for col in numeric_cols:
        X[col] = X[col].astype(float).fillna(X[col].mean())

    # Label-encode categoricals; cast to str first so NaNs become a class.
    label_encoders = {}
    for col in categorical_cols:
        encoder = LabelEncoder()
        X[col + '_encoded'] = encoder.fit_transform(X[col].astype(str))
        label_encoders[col] = encoder

    # Keep only the encoded copies: the raw categorical/object columns would
    # otherwise be fed to the classifier as non-numeric features.
    X = X.drop(columns=categorical_cols)

    # Train model and explain it.
    rfc = RandomForestClassifier()
    rfc.fit(X, y)
    ex = shap.TreeExplainer(rfc)
    shap_values = ex(X)
    plugin = XaiflowPlugin()

    # Map each encoded integer back to its original class label.
    feature_encodings = {
        col + '_encoded': dict(enumerate(encoder.classes_))
        for col, encoder in label_encoders.items()
    }

    experiment_name = "dummytest"
    mlflow.set_experiment(experiment_name=experiment_name)

    output_path = "tests/outputs/test_classification_case_check_list_feature.html"

    class DummyTmpFile:
        """Stand-in for ``NamedTemporaryFile`` that always points at ``output_path``."""

        name = output_path

        def __enter__(self):
            self.name = output_path
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

    mocker.patch("tempfile.NamedTemporaryFile", return_value=DummyTmpFile())
    mocker.patch("os.unlink")  # Prevent deletion of the generated HTML.

    # Avoid real artifact logging.
    mocker.patch("mlflow.log_artifact")

    # One label per explained row, cycling through four groups; this stays
    # aligned with the rows even when the row count is not a multiple of 4.
    group_labels = [f"Group {i % 4 + 1}" for i in range(len(shap_values))]

    with mlflow.start_run(run_name="auto_mpg_test"):
        plugin.log_feature_importance_report(
            shap_values=shap_values,
            feature_encodings=feature_encodings,
            feature_names=list(X.columns),
            group_labels=group_labels,
        )
    html_content_click_test(Path(output_path))
0 commit comments