11import pytest
22from pathlib import Path
3- from xaiflow .report_generator import ReportGenerator
43from xaiflow .mlflow_plugin import XaiflowPlugin
54import numpy as np
65from sklearn .datasets import fetch_openml
7- from sklearn .ensemble import RandomForestRegressor
6+ from sklearn .ensemble import RandomForestRegressor , RandomForestClassifier
7+ from xgboost import XGBClassifier
8+ from catboost import CatBoostClassifier
89from sklearn .preprocessing import LabelEncoder
910import shap
11+ from typing import Callable
1012
1113from playwright .sync_api import sync_playwright
1214
15+
1316def store_report (html_content , filename = "test_report.html" ):
1417 """Helper function to store HTML content to a file."""
1518 outputs_dir = Path ("tests/outputs" )
@@ -19,6 +22,30 @@ def store_report(html_content, filename="test_report.html"):
1922 f .write (html_content )
2023 return html_path
2124
25+ def save_and_click_canvas_wrapper (func : Callable ) -> Callable :
26+ def wrapper (* args , ** kwargs ):
27+ html_content = func (* args , ** kwargs )
28+ html_path = store_report (html_content , func .__name__ + ".html" )
29+ with sync_playwright () as p :
30+ browser = p .chromium .launch ()
31+ context = browser .new_context (record_video_dir = "tests/outputs/videos/" )
32+ page = context .new_page ()
33+ page .goto (f"file://{ html_path .resolve ()} " )
34+ page .wait_for_selector (".importance-chart-container canvas" )
35+ canvas = page .query_selector (".importance-chart-container canvas" )
36+ assert canvas is not None , "Canvas not found in importance chart container"
37+ assert canvas .is_visible (), "Canvas is not visible"
38+ box = canvas .bounding_box ()
39+ x = box ["x" ] + box ["width" ] / 2
40+ y = box ["y" ] + box ["height" ] / 2
41+ page .mouse .click (x , y )
42+ page .wait_for_timeout (1000 )
43+ # Check if canvas is not empty
44+ is_empty = canvas .evaluate ("(node) => {\n const ctx = node.getContext('2d');\n const data = ctx.getImageData(0, 0, node.width, node.height).data;\n const totalPixels = data.length / 4;\n if (totalPixels === 0) return true;\n // Get first pixel color\n const r0 = data[0], g0 = data[1], b0 = data[2], a0 = data[3];\n let sameCount = 0;\n for (let i = 0; i < data.length; i += 4) {\n const r = data[i], g = data[i+1], b = data[i+2], a = data[i+3];\n if (r === r0 && g === g0 && b === b0 && a === a0) {\n sameCount++;\n }\n }\n return (sameCount / totalPixels) >= 0.9;\n }" )
45+ assert not is_empty , "Canvas is empty or nearly empty: 99% of pixels have the same color"
46+ return wrapper
47+
48+ @save_and_click_canvas_wrapper
2249def test_categorical_feature_encodings ():
2350 data = fetch_openml (data_id = 196 , as_frame = True )
2451 X = data .data .copy ()[:100 ]
@@ -59,25 +86,10 @@ def test_categorical_feature_encodings():
5986 feature_encodings = feature_encodings ,
6087 feature_names = list (X .columns ),
6188 )
62- html_path = store_report (html_content , "test_categorical_feature_encodings.html" )
63- with sync_playwright () as p :
64- browser = p .chromium .launch ()
65- context = browser .new_context (record_video_dir = "tests/outputs/videos/" )
66- page = context .new_page ()
67- page .goto (f"file://{ html_path .resolve ()} " )
68- page .wait_for_selector (".importance-chart-container canvas" )
69- canvas = page .query_selector (".importance-chart-container canvas" )
70- assert canvas is not None , "Canvas not found in importance chart container"
71- assert canvas .is_visible (), "Canvas is not visible"
72- box = canvas .bounding_box ()
73- x = box ["x" ] + box ["width" ] / 2
74- y = box ["y" ] + box ["height" ] / 2
75- page .mouse .click (x , y )
76- page .wait_for_timeout (1000 ) # Wait to ensure video is recorded
77- context .close () # This will save the video file
78- browser .close ()
89+ return html_content
7990
8091
92+ @save_and_click_canvas_wrapper
8193def test_no_feature_encodings ():
8294 plugin = XaiflowPlugin ()
8395 html_content = plugin ._generate_html_content (
@@ -87,24 +99,9 @@ def test_no_feature_encodings():
8799 feature_encodings = None ,
88100 feature_names = ['Feature 1' , 'Feature 2' , 'Feature 3' ],
89101 )
90- html_path = store_report (html_content , "test_no_feature_encodings.html" )
91- with sync_playwright () as p :
92- browser = p .chromium .launch ()
93- context = browser .new_context (record_video_dir = "tests/outputs/videos/" )
94- page = context .new_page ()
95- page .goto (f"file://{ html_path .resolve ()} " )
96- page .wait_for_selector (".importance-chart-container canvas" )
97- canvas = page .query_selector (".importance-chart-container canvas" )
98- assert canvas is not None , "Canvas not found in importance chart container"
99- assert canvas .is_visible (), "Canvas is not visible"
100- box = canvas .bounding_box ()
101- x = box ["x" ] + box ["width" ] / 2
102- y = box ["y" ] + box ["height" ] / 2
103- page .mouse .click (x , y )
104- page .wait_for_timeout (1000 ) # Wait to ensure video is recorded
105- context .close () # This will save the video file
106- browser .close ()
102+ return html_content
107103
104+ @save_and_click_canvas_wrapper
108105def test_fix_previous_bug ():
109106 importanceData = {'features' :
110107 ['acv_score_canc_30d' ,
@@ -155,25 +152,43 @@ def test_fix_previous_bug():
155152 feature_encodings = featureEncodings ,
156153 feature_names = featureNames ,
157154 )
158- html_path = store_report (html_content , "test_fix_previous_bug.html" )
159- with sync_playwright () as p :
160- browser = p .chromium .launch ()
161- context = browser .new_context (record_video_dir = "tests/outputs/videos/" )
162- page = context .new_page ()
163- page .goto (f"file://{ html_path .resolve ()} " )
164- page .wait_for_selector (".importance-chart-container canvas" )
165- canvas = page .query_selector (".importance-chart-container canvas" )
166- assert canvas is not None , "Canvas not found in importance chart container"
167- assert canvas .is_visible (), "Canvas is not visible"
168- box = canvas .bounding_box ()
169- x = box ["x" ] + box ["width" ] / 2
170- y = box ["y" ] + box ["height" ] / 2
171- page .mouse .click (x , y )
172- page .wait_for_timeout (1000 ) # Wait to ensure video is recorded
173- context .close () # This will save the video file
174- browser .close ()
175-
176- # def test_importance_chart_canvas_click():
177- # html_path = Path("tests/outputs/test_no_feature_encodings.html")
178- # assert html_path.exists(), "Report HTML file does not exist. Run test_no_feature_encodings first."
155+ return html_content
156+
157+
158+ @save_and_click_canvas_wrapper
159+ def test_classification_case ():
160+ X , y = shap .datasets .adult (n_points = 200 )
161+
162+ # Identify categorical columns
163+ categorical_cols = [col for col in X .columns if X [col ].dtype == 'category' or X [col ].dtype == 'object' ]
164+ numeric_cols = [col for col in X .columns if col not in categorical_cols ]
165+
166+ label_encoders = {}
167+
168+ # Fill missing values manually
169+ for col in numeric_cols :
170+ X [col ] = X [col ].astype (float ).fillna (X [col ].mean ())
171+ for col in categorical_cols :
172+ le = LabelEncoder ()
173+ X [col + '_encoded' ] = le .fit_transform (X [col ].astype (str )) # convert to string in case of NaNs
174+ label_encoders [col ] = le # Save encoder if needed later
179175
176+ # Train model
177+ rfc = RandomForestClassifier ()
178+ rfc .fit (X , y )
179+ ex = shap .TreeExplainer (rfc )
180+ shap_values = ex (X )
181+ plugin = XaiflowPlugin ()
182+
183+ feature_encodings = {}
184+ for col in categorical_cols :
185+ feature_encodings [col + '_encoded' ] = dict (zip (range (len (label_encoders [col ].classes_ )), label_encoders [col ].classes_ ))
186+ html_content = plugin ._generate_html_content (
187+ importance_data = {'features' : list (X .columns ), 'values' : np .abs (shap_values .values ).mean (axis = 0 ).tolist ()},
188+ shap_values = shap_values .values ,
189+ feature_values = shap_values .data ,
190+ base_values = shap_values .base_values [0 ],
191+ feature_encodings = feature_encodings ,
192+ feature_names = list (X .columns ),
193+ )
194+ return html_content
0 commit comments