Skip to content

Commit 457d8dd

Browse files
committed
add group feature
1 parent 9cc8df3 commit 457d8dd

File tree

7 files changed

+121
-68
lines changed

7 files changed

+121
-68
lines changed

src/xaiflow/mlflow_plugin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def log_feature_importance_report(
3232
shap_values: Explanation,
3333
feature_encodings: Optional[Dict[str, Dict[int, str]]] = None,
3434
importance_values: List[float] | np.ndarray = None,
35+
group_labels: Optional[List[str]] = None,
3536
run_id: Optional[str] = None,
3637
artifact_path: str = "reports",
3738
report_name: str = "feature_importance_report.html",
@@ -44,6 +45,7 @@ def log_feature_importance_report(
4445
feature_names: List of feature names
4546
importance_values: List of importance values corresponding to features
4647
shap_values: Optional SHAP values matrix (samples x features)
48+
group_labels: Optional list of group labels for each sample
4749
run_id: MLflow run ID (uses active run if None)
4850
artifact_path: Path within MLflow artifacts to store the report
4951
report_name: Name of the HTML report file
@@ -74,6 +76,10 @@ def log_feature_importance_report(
7476
shap_values = shap_values[..., -1]
7577
base_values = float(base_values[-1])
7678

79+
if group_labels is not None:
80+
if len(group_labels) != shap_values.shape[0]:
81+
raise ValueError("group_labels length must match the number of samples in shap_values.")
82+
7783
# Use active run if no run_id provided
7884
if run_id is None:
7985
active_run = mlflow.active_run()
@@ -105,6 +111,7 @@ def log_feature_importance_report(
105111
html_content = self._generate_html_content(
106112
importance_data=importance_data,
107113
shap_values=shap_values,
114+
group_labels=group_labels or [], # Default to empty list if None
108115
feature_values=feature_values,
109116
base_values=base_values,
110117
feature_encodings=feature_encodings,
@@ -143,6 +150,7 @@ def _generate_html_content(
143150
importance_data: Dict[str, Any],
144151
shap_values: List[List[float]],
145152
feature_values: List[float] = None,
153+
group_labels: List[str] = None,
146154
base_values: List[float] = None,
147155
feature_encodings: Optional[Dict[str, Dict[int, str]]] = None,
148156
feature_names: List[str] = None
@@ -202,6 +210,7 @@ def _generate_html_content(
202210
timestamp=current_time,
203211
importance_data=importance_data, # Pass as Python dict
204212
shap_values=shap_values, # Pass as Python list
213+
group_labels=group_labels or [], # Pass as Python list or empty list
205214
feature_values=feature_values, # Pass as Python list or None
206215
base_values=base_values or [0] * 10, # Todo: fix this once we hand over numpy arrays
207216
feature_encodings=feature_encodings or {}, # Pass as optional dict

src/xaiflow/templates/assets/bundle.js

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/xaiflow/templates/assets/bundle.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/xaiflow/templates/components/ChartManager.svelte

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
baseValues: number[] | number; // Base values for SHAP calculations
1313
featureNames?: string[]; // Optional prop for feature names
1414
isHigherOutputBetter?: boolean; // Optional prop to determine if higher output is better
15+
groupLabels: string[]; // Optional prop for group labels
1516
}
1617
1718
let { importanceData,
@@ -21,11 +22,35 @@
2122
baseValues,
2223
featureNames,
2324
isHigherOutputBetter,
25+
groupLabels,
2426
}: Props = $props();
2527
2628
// Reactive state for selected label using $state
2729
let selectedLabel: string | null = $state(null);
2830
let showDeepDive = $state(false);
31+
let selectedGroup: string | null = $state(null);
32+
// Compute unique group labels
33+
let uniqueGroups: string[] = $derived(Array.from(new Set(groupLabels || [])));
34+
console.log('ChartManager: Loaded with props:', {
35+
importanceData,
36+
shapValues,
37+
featureValues,
38+
featureEncodings,
39+
baseValues,
40+
featureNames,
41+
isHigherOutputBetter,
42+
groupLabels
43+
});
44+
45+
// Compute selectedShapValues based on selectedGroup
46+
let selectedShapValues = $derived((selectedGroup && selectedGroup !== "" && selectedGroup !== "All")
47+
? shapValues.filter((_, idx) => groupLabels[idx] === selectedGroup)
48+
: shapValues);
49+
50+
let selectedFeatureValues = $derived((selectedGroup && selectedGroup !== "" && selectedGroup !== "All")
51+
? featureValues.filter((_, idx) => groupLabels[idx] === selectedGroup)
52+
: featureValues);
53+
console.log('ChartManager: selectedShapValues computed:', selectedShapValues);
2954
3055
console.log("ChartManager", importanceData);
3156
console.log('ChartManager: 1/4 command in file');
@@ -54,9 +79,22 @@
5479
</script>
5580

5681
<div class="chart-manager">
57-
<div style="display: flex; gap: 1.5rem; align-items: center; margin-bottom: 1.5rem;">
58-
<button type="button" on:click={() => showDeepDive = false} class:selected={!showDeepDive}>Charts</button>
59-
<button id="deepdive-button" type="button" on:click={() => showDeepDive = true} class:selected={showDeepDive}>Deep Dive</button>
82+
<div style="display: flex; gap: 1.5rem; align-items: center; margin-bottom: 1.5rem; justify-content: space-between;">
83+
<div style="display: flex; gap: 1.5rem; align-items: center;">
84+
<button type="button" on:click={() => showDeepDive = false} class:selected={!showDeepDive}>Charts</button>
85+
<button id="deepdive-button" type="button" on:click={() => showDeepDive = true} class:selected={showDeepDive}>Deep Dive</button>
86+
</div>
87+
{#if uniqueGroups.length > 0}
88+
<div style="margin-left: auto;">
89+
<label for="group-dropdown" style="margin-right: 0.5em; font-size: 1em;">Group:</label>
90+
<select id="group-dropdown" bind:value={selectedGroup} on:change={(e) => selectedGroup = e.target.value} style="font-size: 1em; padding: 0.3em 0.7em;">
91+
<option value="">All</option>
92+
{#each uniqueGroups as group}
93+
<option value={group}>{group}</option>
94+
{/each}
95+
</select>
96+
</div>
97+
{/if}
6098
</div>
6199
{#if !showDeepDive}
62100
<div class="charts-row">
@@ -75,8 +113,8 @@
75113
<h3>SHAP Values</h3>
76114
<div class="chart-container">
77115
<ScatterShapValues
78-
shapValues={shapValues}
79-
featureValues={featureValues}
116+
shapValues={selectedShapValues}
117+
featureValues={selectedFeatureValues}
80118
bind:selectedFeatureIndex={selectedFeatureIndex}
81119
bind:selectedFeature={selectedLabel}
82120
isHigherOutputBetter={true}
@@ -87,8 +125,8 @@
87125
</div>
88126
{:else}
89127
<DeepDiveManager
90-
shapValues={shapValues}
91-
featureValues={featureValues}
128+
shapValues={selectedShapValues}
129+
featureValues={selectedFeatureValues}
92130
selectedFeatureIndex={selectedFeatureIndex}
93131
selectedFeature={selectedLabel}
94132
baseValues={baseValues}

src/xaiflow/templates/components/DeepDiveManager.svelte

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
});
3333
let selectedObservationIndex = $state(0);
3434
let currentPage = $state(0);
35-
let totalObservations = shapValues.length;
35+
let totalObservations = $derived(shapValues.length);
3636
let filterText = $state("");
3737
3838
let allObservations = $derived(

src/xaiflow/templates/report.html

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ <h1>Xflow report by cloudexplain</h1>
289289
const baseValues = {{ base_values | safe }};
290290
const featureEncodings = {{ feature_encodings | safe }};
291291
const featureNames = {{ feature_names | safe }};
292+
const groupLabels = {{ group_labels | safe }};
292293

293294
// Initialize ChartManager with all props needed for both managers
294295
if (window.ChartManager && importanceData) {
@@ -306,7 +307,8 @@ <h1>Xflow report by cloudexplain</h1>
306307
featureValues: featureValues,
307308
baseValues: baseValues,
308309
featureEncodings: featureEncodings,
309-
featureNames: featureNames
310+
featureNames: featureNames,
311+
groupLabels: groupLabels,
310312
}
311313
});
312314
console.log('ChartManager with DeepDiveManager mounted successfully!');

tests/test_mlflow_plugin.py

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -207,59 +207,6 @@ def test_no_feature_encodings():
207207
)
208208
return html_content
209209

210-
@save_and_click_canvas_wrapper
211-
def test_fix_previous_bug():
212-
importanceData = {'features':
213-
['acv_score_canc_30d',
214-
'avg_canc_dealer_no_weighted',
215-
'ctr_usa_sec_inc_voice_a6m',
216-
'avg_canc_reseller_id_weighted',
217-
'ctr_usa_kb_data_usg_a3m',
218-
'ctr_sales_channel_current',
219-
'ctr_cancellations_per_year',
220-
'avg_vvl_reseller_id_weighted',
221-
'ctr_start_days',
222-
'ctr_min_duration_date_crm_days',
223-
'rlz',
224-
'vvl_l_event_days',
225-
'avg_vvl_sales_channel',
226-
'ctr_dealer_no_current',
227-
'avg_canc_sales_channel',
228-
'prt_cancellation_page_visit_90d_count',
229-
'acv_score_vvl_30d'],
230-
'values': [0.5000000000000614,
231-
0.49999999999993844]}
232-
shapValues = [[0.05231021109253422, -0.05231021109253736], [0.0073606489440402965, -0.007360648944034653], [-0.01633880222219225, 0.016338802222170094], [0.012322311243639975, -0.012322311243637033], [-0.004445322661143976, 0.004445322661143468], [0.0009611405151175154, -0.0009611405151178431], [0.005596997502669034, -0.0055969975026683915], [-0.0008618250588141368, 0.0008618250587932731], [0.0016991238750754237, -0.0016991238750824476], [0.0048252568432011304, -0.004825256843199152], [-0.00038499217151075256, 0.00038499217151299176], [0.005172501948575322, -0.005172501948575318], [-0.003383349580534422, 0.003383349580535079], [-0.017147577240666855, 0.017147577240670973], [0.008064862968425773, -0.008064862968423504], [0.0018500348673166761, -0.0018500348673163927], [0.006529750924148127, -0.006529750924151195]]
233-
featureValues = [0.03421833738684654, 0.022704629679359795, 15.0, 0.022704629679359795, 30193.0, 241.0, 0.0, 0.022739316468840073, 951.5416666666666, 2912717.0, 30.0, 9999.0, 0.019356054262267625, 9.0, 0.02191634567074192, 0.0, 0.2959745228290558]
234-
baseValues = [0.9058690282100686, 0.09413097178993132]
235-
featureEncodings = None
236-
featureNames = ['acv_score_canc_30d',
237-
'avg_canc_dealer_no_weighted',
238-
'ctr_usa_sec_inc_voice_a6m',
239-
'avg_canc_reseller_id_weighted',
240-
'ctr_usa_kb_data_usg_a3m',
241-
'ctr_sales_channel_current',
242-
'ctr_cancellations_per_year',
243-
'avg_vvl_reseller_id_weighted',
244-
'ctr_start_days',
245-
'ctr_min_duration_date_crm_days',
246-
'rlz',
247-
'vvl_l_event_days',
248-
'avg_vvl_sales_channel',
249-
'ctr_dealer_no_current',
250-
'avg_canc_sales_channel',
251-
'prt_cancellation_page_visit_90d_count',
252-
'acv_score_vvl_30d']
253-
plugin = XaiflowPlugin()
254-
html_content = plugin._generate_html_content(
255-
importance_data=importanceData,
256-
shap_values=shapValues,
257-
feature_values=featureValues,
258-
feature_encodings=featureEncodings,
259-
feature_names=featureNames,
260-
)
261-
return html_content
262-
263210

264211
def test_classification_case(mocker):
265212
X, y = shap.datasets.adult(n_points=200)
@@ -322,4 +269,61 @@ def __exit__(self, exc_type, exc_val, exc_tb):
322269
feature_names=list(X.columns),
323270
)
324271
html_content_click_test(Path(output_path))
272+
# return html_content
273+
274+
275+
def test_classification_case_check_list_feature(mocker):
276+
X, y = shap.datasets.adult(n_points=200)
277+
278+
# Identify categorical columns
279+
categorical_cols = [col for col in X.columns if X[col].dtype == 'category' or X[col].dtype == 'object']
280+
numeric_cols = [col for col in X.columns if col not in categorical_cols]
281+
282+
label_encoders = {}
283+
284+
# Fill missing values manually
285+
for col in numeric_cols:
286+
X[col] = X[col].astype(float).fillna(X[col].mean())
287+
for col in categorical_cols:
288+
le = LabelEncoder()
289+
X[col + '_encoded'] = le.fit_transform(X[col].astype(str)) # convert to string in case of NaNs
290+
label_encoders[col] = le # Save encoder if needed later
291+
292+
# Train model
293+
rfc = RandomForestClassifier()
294+
rfc.fit(X, y)
295+
ex = shap.TreeExplainer(rfc)
296+
shap_values = ex(X)
297+
plugin = XaiflowPlugin()
298+
299+
feature_encodings = {}
300+
for col in categorical_cols:
301+
feature_encodings[col + '_encoded'] = dict(zip(range(len(label_encoders[col].classes_)), label_encoders[col].classes_))
302+
experiment_name = "dummytest"
303+
mlflow.set_experiment(experiment_name=experiment_name)
304+
305+
output_path = f"tests/outputs/test_classification_case_check_list_feature.html"
306+
class DummyTmpFile:
307+
name = output_path
308+
def __enter__(self):
309+
self.name = output_path
310+
# import pdb; pdb.set_trace() # Debugging breakpoint
311+
return self
312+
def __exit__(self, exc_type, exc_val, exc_tb):
313+
pass
314+
315+
mocker.patch("tempfile.NamedTemporaryFile", return_value=DummyTmpFile())
316+
mocker.patch("os.unlink") # Prevent deletion
317+
318+
# Optionally patch mlflow.log_artifact if you want to avoid real logging
319+
mocker.patch("mlflow.log_artifact")
320+
321+
with mlflow.start_run(run_name="auto_mpg_test"):
322+
plugin.log_feature_importance_report(
323+
shap_values=shap_values,
324+
feature_encodings=feature_encodings,
325+
feature_names=list(X.columns),
326+
group_labels=["Group 1", "Group 2", "Group 3", "Group 4"] * int(len(shap_values) / 4) # Example group labels
327+
)
328+
html_content_click_test(Path(output_path))
325329
# return html_content

0 commit comments

Comments
 (0)