Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/eventdisplay_ml/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,14 +730,19 @@ def load_training_data(model_configs, file_list, analysis_type):
observatory=model_configs.get("observatory", "veritas"),
)
if analysis_type == "stereo_analysis":
df_flat["MCxoff"] = _to_numpy_1d(df["MCxoff"], np.float32)
df_flat["MCyoff"] = _to_numpy_1d(df["MCyoff"], np.float32)
df_flat["MCe0"] = np.log10(_to_numpy_1d(df["MCe0"], np.float32))
new_cols = {
"MCxoff": _to_numpy_1d(df["MCxoff"], np.float32),
"MCyoff": _to_numpy_1d(df["MCyoff"], np.float32),
"MCe0": np.log10(_to_numpy_1d(df["MCe0"], np.float32)),
}
elif analysis_type == "classification":
df_flat["ze_bin"] = zenith_in_bins(
90.0 - _to_numpy_1d(df["ArrayPointing_Elevation"], np.float32),
model_configs.get("zenith_bins_deg", []),
)
new_cols = {
"ze_bin": zenith_in_bins(
90.0 - _to_numpy_1d(df["ArrayPointing_Elevation"], np.float32),
model_configs.get("zenith_bins_deg", []),
)
}
df_flat = pd.concat([df_flat, pd.DataFrame(new_cols, index=df_flat.index)], axis=1)
Comment thread
GernotMaier marked this conversation as resolved.
Outdated

dfs.append(df_flat)

Expand Down
144 changes: 144 additions & 0 deletions src/eventdisplay_ml/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def evaluate_regression_model(model, x_test, y_test, df, x_cols, y_data, name):
feature_importance(model, x_cols, y_data.columns, name)
if name == "xgboost":
shap_feature_importance(model, x_test, y_data.columns)
# Optional for now
# shap_feature_importance_by_energy(
# model, x_test, df, y_test, y_data.columns
# )

df_pred = pd.DataFrame(y_pred, columns=target_features("stereo_analysis"))
calculate_resolution(
Expand Down Expand Up @@ -253,3 +257,143 @@ def shap_feature_importance(model, x_data, target_names, max_points=1000, n_top=
for j in idx[:n_top]:
if j < n_features:
_logger.info(f"{x_data.columns[j]:25s} {imp[j]:.6e}")


def shap_feature_importance_by_energy(
model,
x_test,
df,
y_test,
target_names,
log_e_min=-2.0,
log_e_max=2.5,
n_bins=9,
max_points=1000,
n_top=5,
):
"""Calculate SHAP feature importance for each energy bin.

Computes SHAP values separately for events in different energy ranges,
allowing analysis of feature importance as a function of energy.
Uses the same energy binning as calculate_resolution for consistency.
Outputs results in tabular format for easy comparison across energy bins.
"""
# Extract energy values and create bins
mce0_values = df.loc[y_test.index, "MCe0"].values
bins = np.linspace(log_e_min, log_e_max, n_bins + 1)
bin_indices = np.digitize(mce0_values, bins)
Comment thread
GernotMaier marked this conversation as resolved.
Outdated

n_features = len(x_test.columns)
n_targets = len(target_names)

# Store importance values for each target across all bins
target_importance_data = {target: {} for target in target_names}
bin_info = []

# Collect stratified samples for all bins, then compute SHAP once
sampled_frames = []
sampled_bin_labels = []

for bin_idx in range(1, n_bins + 1):
mask = bin_indices == bin_idx
n_events = mask.sum()

if n_events == 0:
continue

bin_lower = bins[bin_idx - 1]
bin_upper = bins[bin_idx]
mean_log_e = mce0_values[mask].mean()

bin_label = f"LogE={mean_log_e:.2f}"
bin_info.append(
{
"label": bin_label,
"mean_log_e": mean_log_e,
"n_events": n_events,
"range": f"[{bin_lower:.2f}, {bin_upper:.2f}]",
Comment thread
GernotMaier marked this conversation as resolved.
Outdated
}
)

x_bin = x_test.iloc[mask]
n_sample = min(len(x_bin), max_points)
x_sample = x_bin.sample(n=n_sample, random_state=None)

sampled_frames.append(x_sample)
sampled_bin_labels.extend([bin_label] * len(x_sample))

if not sampled_frames:
_logger.info("No events found in any energy bin for SHAP calculation.")
return

x_sampled_all = pd.concat(sampled_frames, axis=0)
dmatrix = xgb.DMatrix(x_sampled_all)
shap_vals = model.get_booster().predict(dmatrix, pred_contribs=True)
shap_vals = shap_vals.reshape(len(x_sampled_all), n_targets, n_features + 1)

# Aggregate SHAP importance per bin from the single SHAP run
sampled_bin_labels = np.array(sampled_bin_labels)
for i, target in enumerate(target_names):
target_shap = shap_vals[:, i, :-1]
for info in bin_info:
bin_label = info["label"]
bin_mask = sampled_bin_labels == bin_label
if not np.any(bin_mask):
continue

imp = np.abs(target_shap[bin_mask]).mean(axis=0)
for j, feature_name in enumerate(x_test.columns):
if feature_name not in target_importance_data[target]:
target_importance_data[target][feature_name] = {}
target_importance_data[target][feature_name][bin_label] = imp[j]

# Create and display tables for each target
_logger.info(f"\n{'=' * 100}")
_logger.info("SHAP Feature Importance by Energy Bin (Tabular Format)")
_logger.info(f"Calculated over {n_bins} bins [{log_e_min}, {log_e_max}]")
_logger.info(f"{'=' * 100}")

# Display bin information
_logger.info("\nEnergy Bin Information:")
for info in bin_info:
_logger.info(f" {info['label']:12s}: Range {info['range']:15s}, N = {info['n_events']:6d}")

for target in target_names:
_logger.info(f"\n\n=== SHAP Importance for {target} ===")

# Find top N features in each bin, then take union of all top features
all_top_features = set()
for info in bin_info:
bin_label = info["label"]
# Get importance values for this bin
bin_importance = {
feature: values.get(bin_label, 0)
for feature, values in target_importance_data[target].items()
}
# Get top N features for this bin
top_in_bin = sorted(bin_importance.items(), key=lambda x: x[1], reverse=True)[:n_top]
all_top_features.update([f[0] for f in top_in_bin])

# Sort features by their average importance across all bins
feature_avg_importance = {}
for feature_name in all_top_features:
values = [
target_importance_data[target][feature_name].get(info["label"], 0)
for info in bin_info
]
feature_avg_importance[feature_name] = np.mean(values)

sorted_features = sorted(feature_avg_importance.items(), key=lambda x: x[1], reverse=True)

# Build DataFrame with all features that were top N in at least one bin
data_rows = []
for feature_name, _ in sorted_features:
row = {"Feature": feature_name}
for info in bin_info:
bin_label = info["label"]
value = target_importance_data[target][feature_name].get(bin_label, np.nan)
row[bin_label] = value
data_rows.append(row)

df_table = pd.DataFrame(data_rows)
_logger.info(f"\n{df_table.to_markdown(index=False, floatfmt='.4e')}")
Loading