Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ __pycache__
_version.py
build/
.coverage
.vscode
tmp_testing*
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dynamic = [
"version",
]
dependencies = [
"astropy",
"awkward",
"awkward-pandas",
"joblib",
Expand Down
76 changes: 76 additions & 0 deletions src/eventdisplay_ml/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,82 @@ def energy_in_bins(df_chunk, bins):
return df_chunk["e_bin"]


def energy_interpolation_bins(df_chunk, bins):
    """Compute neighboring energy bins and interpolation weights per event.

    For every event with a valid reconstructed energy, find the two energy
    bins whose centers (in log10(E/TeV)) bracket the event energy, and the
    linear weight between them.

    Parameters
    ----------
    df_chunk : pandas.DataFrame
        Chunk containing an ``Erec`` column in TeV.
    bins : list[dict | None]
        Energy bin definitions with ``E_min`` and ``E_max`` in log10(E/TeV).
        ``None`` entries are ignored.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        ``(e_bin_lo, e_bin_hi, e_alpha)`` where ``e_alpha`` is in ``[0, 1]``.
        Events below (above) the lowest (highest) bin center are clamped to
        that edge bin with ``e_bin_lo == e_bin_hi`` and ``e_alpha = 0``.
        Invalid events (``Erec <= 0``) get ``e_bin_lo = e_bin_hi = -1`` and
        ``e_alpha = 0``.
    """
    centers = np.array(
        [(b["E_min"] + b["E_max"]) / 2 if b is not None else np.nan for b in bins]
    )
    n_events = len(df_chunk)

    e_bin_lo = np.full(n_events, -1, dtype=np.int32)
    e_bin_hi = np.full(n_events, -1, dtype=np.int32)
    e_alpha = np.zeros(n_events, dtype=np.float32)

    # Nothing to do for an empty chunk or when no usable bin exists.
    if n_events == 0 or np.isnan(centers).all():
        return e_bin_lo, e_bin_hi, e_alpha

    valid_event_mask = df_chunk["Erec"].to_numpy() > 0
    if not np.any(valid_event_mask):
        return e_bin_lo, e_bin_hi, e_alpha

    # Sort the usable bin centers, remembering their original bin indices.
    valid_center_idx = np.flatnonzero(~np.isnan(centers))
    order = np.argsort(centers[valid_center_idx])
    sorted_idx = valid_center_idx[order]
    sorted_centers = centers[valid_center_idx][order]

    valid_event_idx = np.flatnonzero(valid_event_mask)
    log_e = np.log10(df_chunk.loc[valid_event_mask, "Erec"].to_numpy())

    if len(sorted_idx) == 1:
        # Single usable bin: all valid events map onto it with zero weight.
        only_idx = int(sorted_idx[0])
        e_bin_lo[valid_event_idx] = only_idx
        e_bin_hi[valid_event_idx] = only_idx
        return e_bin_lo, e_bin_hi, e_alpha

    insert_pos = np.searchsorted(sorted_centers, log_e, side="left")

    # Clipping alone clamps out-of-range events to the first/last bin:
    # there lo == hi, so denom == 0 and alpha stays 0 below.
    lo_pos = np.clip(insert_pos - 1, 0, len(sorted_centers) - 1)
    hi_pos = np.clip(insert_pos, 0, len(sorted_centers) - 1)

    lo_centers = sorted_centers[lo_pos]
    hi_centers = sorted_centers[hi_pos]
    denom = hi_centers - lo_centers

    alpha = np.zeros_like(log_e, dtype=np.float64)
    interp_mask = denom > 0
    alpha[interp_mask] = (log_e[interp_mask] - lo_centers[interp_mask]) / denom[interp_mask]
    alpha = np.clip(alpha, 0.0, 1.0)

    e_bin_lo[valid_event_idx] = sorted_idx[lo_pos]
    e_bin_hi[valid_event_idx] = sorted_idx[hi_pos]
    e_alpha[valid_event_idx] = alpha.astype(np.float32)

    return e_bin_lo, e_bin_hi, e_alpha


def print_variable_statistics(df):
"""
Print min, max, mean, and RMS for each variable in the DataFrame.
Expand Down
57 changes: 44 additions & 13 deletions src/eventdisplay_ml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from eventdisplay_ml import data_processing, diagnostic_utils, features, utils
from eventdisplay_ml.data_processing import (
energy_in_bins,
energy_interpolation_bins,
flatten_feature_data,
zenith_in_bins,
)
Expand Down Expand Up @@ -100,7 +100,7 @@ def load_classification_models(model_prefix, model_name):
pattern = f"{model_prefix.name}_ebin*.joblib"
files = sorted(model_dir_path.glob(pattern))

_logger.info("Loading classification models")
_logger.info(f"Loading classification models from {files}")
for file in files:
match = re.search(r"_ebin(\d+)\.joblib$", file.name)
if not match:
Expand Down Expand Up @@ -359,13 +359,18 @@ def apply_classification_models(df, model_configs, threshold_keys):
tel_config = model_configs.get("tel_config")
n_tel = tel_config["max_tel_id"] + 1 if tel_config else 4

for e_bin, group_df in df.groupby("e_bin"):
e_bin = int(e_bin)
if e_bin == -1:
_logger.warning("Skipping events with e_bin = -1")
for bin_pair, group_df in df.groupby(["e_bin_lo", "e_bin_hi"], dropna=False):
e_bin_lo, e_bin_hi = int(bin_pair[0]), int(bin_pair[1])
if e_bin_lo == -1 or e_bin_hi == -1:
_logger.warning("Skipping events with invalid energy interpolation bins")
continue

_logger.info(f"Processing {len(group_df)} events with bin={e_bin}")
_logger.info(
"Processing %d events with interpolation bins (%d, %d)",
len(group_df),
e_bin_lo,
e_bin_hi,
)

flatten_data = flatten_feature_data(
group_df,
Expand All @@ -376,14 +381,35 @@ def apply_classification_models(df, model_configs, threshold_keys):
observatory=model_configs.get("observatory", "veritas"),
preview_rows=model_configs.get("preview_rows", 20),
)
model = models[e_bin]["model"]
flatten_data = flatten_data.reindex(columns=models[e_bin]["features"])
class_probs = model.predict_proba(flatten_data)[:, 1]
model_lo = models[e_bin_lo]["model"]
model_hi = models[e_bin_hi]["model"]
flatten_lo = flatten_data.reindex(columns=models[e_bin_lo]["features"])
flatten_hi = flatten_data.reindex(columns=models[e_bin_hi]["features"])

class_probs_lo = model_lo.predict_proba(flatten_lo)[:, 1]
if e_bin_lo == e_bin_hi:
class_probs = class_probs_lo
else:
class_probs_hi = model_hi.predict_proba(flatten_hi)[:, 1]
alpha = group_df["e_alpha"].to_numpy(dtype=np.float32)
class_probs = (1.0 - alpha) * class_probs_lo + alpha * class_probs_hi
class_probability[group_df.index] = class_probs

thresholds = models[e_bin].get("thresholds", {})
for eff, threshold in thresholds.items():
thresholds_lo = models[e_bin_lo].get("thresholds", {})
thresholds_hi = models[e_bin_hi].get("thresholds", {})
for eff in threshold_keys:
if eff in is_gamma:
thr_lo = thresholds_lo.get(eff)
if thr_lo is None:
continue
if e_bin_lo == e_bin_hi:
threshold = thr_lo
else:
thr_hi = thresholds_hi.get(eff)
if thr_hi is None:
continue
alpha = group_df["e_alpha"].to_numpy(dtype=np.float32)
threshold = (1.0 - alpha) * thr_lo + alpha * thr_hi
is_gamma[eff][group_df.index] = (class_probs >= threshold).astype(np.uint8)

return class_probability, is_gamma
Expand Down Expand Up @@ -469,7 +495,12 @@ def process_file_chunked(analysis_type, model_configs):
# index out-of-bounds when indexing chunk-sized output arrays
df_chunk = df_chunk.reset_index(drop=True)
if analysis_type == "classification":
df_chunk["e_bin"] = energy_in_bins(df_chunk, model_configs["energy_bins_log10_tev"])
e_bin_lo, e_bin_hi, e_alpha = energy_interpolation_bins(
df_chunk, model_configs["energy_bins_log10_tev"]
)
df_chunk["e_bin_lo"] = e_bin_lo
df_chunk["e_bin_hi"] = e_bin_hi
df_chunk["e_alpha"] = e_alpha
df_chunk["ze_bin"] = zenith_in_bins(
90.0 - df_chunk["ArrayPointing_Elevation"].values,
model_configs["zenith_bins_deg"],
Expand Down
2 changes: 1 addition & 1 deletion src/eventdisplay_ml/scripts/optimize_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

_ALPHA = 1.0 / 6.0
# expect Crab spectrum for input signal rate
_CRAB_INDEX = 2.6
_CRAB_INDEX = 2.63


def _validate_source_index(source_index):
Expand Down
72 changes: 72 additions & 0 deletions tests/test_classification_apply_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Tests for energy-interpolated classification apply."""

import numpy as np
import pandas as pd

from eventdisplay_ml import data_processing, models


class DummyXGBClassifier:
    """Simple classifier returning a fixed gamma probability."""

    def __init__(self, proba):
        self._proba = float(proba)

    def predict_proba(self, x_data):
        """Return constant class probabilities for all rows."""
        gamma_prob = np.full(len(x_data), self._proba, dtype=np.float32)
        background_prob = 1.0 - gamma_prob
        return np.column_stack([background_prob, gamma_prob])


def test_energy_interpolation_bins_linear_weights():
    """Interpolation bins should bracket log-energy with linear alpha."""
    # Bin centers in log10(E/TeV) are -0.75, -0.25, and 0.25.
    edges = [(-1.0, -0.5), (-0.5, 0.0), (0.0, 0.5)]
    bins = [{"E_min": e_min, "E_max": e_max} for e_min, e_max in edges]
    frame = pd.DataFrame({"Erec": [10**-0.75, 10**0.25, 10**0.8]})

    e_bin_lo, e_bin_hi, e_alpha = data_processing.energy_interpolation_bins(frame, bins)

    # First event sits on the lowest center, last is clamped above the grid.
    np.testing.assert_array_equal(e_bin_lo, np.array([0, 1, 2], dtype=np.int32))
    np.testing.assert_array_equal(e_bin_hi, np.array([0, 2, 2], dtype=np.int32))
    np.testing.assert_allclose(e_alpha, np.array([0.0, 1.0, 0.0], dtype=np.float32), atol=1e-7)


def test_apply_classification_models_interpolates_probabilities_and_thresholds(monkeypatch):
    """Classification apply should linearly interpolate between neighboring energy-bin models."""
    df = pd.DataFrame(
        {
            "e_bin_lo": [0, 0],
            "e_bin_hi": [1, 1],
            "e_alpha": [0.25, 0.75],
            "dummy": [1.0, 2.0],
        }
    )

    # Two per-bin models with fixed probabilities and 50%-efficiency thresholds.
    bin_models = {}
    for e_bin, proba, threshold in ((0, 0.2, 0.4), (1, 1.0, 0.8)):
        bin_models[e_bin] = {
            "model": DummyXGBClassifier(proba),
            "features": ["dummy"],
            "thresholds": {50: threshold},
        }
    model_configs = {"models": bin_models}

    monkeypatch.setattr(models, "flatten_feature_data", lambda *args, **kwargs: df[["dummy"]])

    class_probability, is_gamma = models.apply_classification_models(df, model_configs, [50])

    # alpha = [0.25, 0.75] interpolates probabilities 0.2 -> 1.0 to [0.4, 0.8].
    np.testing.assert_allclose(class_probability, np.array([0.4, 0.8], dtype=np.float32), atol=1e-7)

    # Thresholds are interpolated the same way: [0.5, 0.7]
    np.testing.assert_array_equal(is_gamma[50], np.array([0, 1], dtype=np.uint8))
107 changes: 107 additions & 0 deletions tests/test_optimize_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Tests for optimize_classification grid and interpolation helpers."""

import numpy as np
import pytest

from eventdisplay_ml.scripts.optimize_classification import (
_CRAB_INDEX,
_build_fine_rate_grid,
_interpolate_efficiency_surface,
_inverse_cosine_to_zenith,
_spectral_reweight_factor,
)


def test_build_fine_rate_grid_interpolates_rates_on_requested_axes():
    """Interpolate rate surfaces onto a finer energy and 1/cos(ze) grid."""

    def on_rate_model(energy_vals, inv_cos_vals):
        # Linear surface so bilinear interpolation reproduces it exactly.
        return 10.0 + 2.0 * energy_vals + 3.0 * inv_cos_vals

    def background_rate_model(energy_vals, inv_cos_vals):
        return 4.0 + energy_vals + 0.5 * inv_cos_vals

    energy = np.array([0.0, 1.0, 0.0, 1.0], dtype=float)
    zenith = np.array([0.0, 0.0, 60.0, 60.0], dtype=float)
    inverse_cosine_zenith = 1.0 / np.cos(np.deg2rad(zenith))

    fine_grid = _build_fine_rate_grid(
        energy,
        zenith,
        on_rate_model(energy, inverse_cosine_zenith),
        background_rate_model(energy, inverse_cosine_zenith),
        energy_bin_width=0.5,
        inverse_cosine_zenith_bin_width=0.5,
    )

    energy_axis = np.array([0.0, 0.5, 1.0], dtype=float)
    inv_cos_axis = np.array([1.0, 1.5, 2.0], dtype=float)

    assert np.allclose(fine_grid["energy_axis"], energy_axis)
    assert np.allclose(fine_grid["zenith_axis"], _inverse_cosine_to_zenith(inv_cos_axis))

    energy_mesh, inv_cos_mesh = np.meshgrid(energy_axis, inv_cos_axis, indexing="xy")
    assert np.allclose(
        fine_grid["on_rate"],
        on_rate_model(energy_mesh.ravel(), inv_cos_mesh.ravel()),
    )
    assert np.allclose(
        fine_grid["background_rate"],
        background_rate_model(energy_mesh.ravel(), inv_cos_mesh.ravel()),
    )


def test_interpolate_efficiency_surface_uses_energy_and_cos_zenith():
    """Interpolate efficiency on energy and cos(ze), clipping at model edges."""
    model_energy_axis = np.array([0.0, 1.0], dtype=float)
    model_zenith_axis = np.array([0.0, 60.0], dtype=float)
    model_cos_zenith_axis = np.cos(np.deg2rad(model_zenith_axis))

    def efficiency_model(energy_vals, cos_zenith_vals):
        # Linear in energy and cos(ze), so interior interpolation is exact.
        return 0.2 + 0.1 * energy_vals + 0.3 * cos_zenith_vals

    efficiency_surface = np.array(
        [efficiency_model(model_energy_axis, cos_ze) for cos_ze in model_cos_zenith_axis],
        dtype=float,
    )

    # Second target lies outside the model grid on both axes and must clip.
    target_energy = np.array([0.5, -1.0], dtype=float)
    target_zenith = np.array([np.rad2deg(np.arccos(0.75)), 80.0], dtype=float)

    interpolated = _interpolate_efficiency_surface(
        model_energy_axis,
        model_zenith_axis,
        efficiency_surface,
        target_energy,
        target_zenith,
    )

    expected = np.array(
        [
            efficiency_model(0.5, 0.75),
            efficiency_model(0.0, model_cos_zenith_axis.min()),
        ],
        dtype=float,
    )

    assert np.allclose(interpolated, expected)


def test_spectral_reweight_factor_is_unity_for_crab_index():
    """Crab-to-Crab reweighting should keep rates unchanged."""
    log10_energy = np.linspace(-1.0, 1.0, 3)
    weights = _spectral_reweight_factor(log10_energy, _CRAB_INDEX)
    assert np.allclose(weights, np.ones_like(log10_energy))


def test_spectral_reweight_factor_reweights_power_law_relative_to_crab():
    """Reweight factor follows E^-(index - crab_index), normalized at 1 TeV."""
    log10_energy = np.array([-1.0, 0.0, 1.0], dtype=float)
    # Index offset of exactly one relative to Crab gives factors 1/E.
    source_index = _CRAB_INDEX + 1.0
    weights = _spectral_reweight_factor(log10_energy, source_index)
    assert np.allclose(weights, np.array([10.0, 1.0, 0.1], dtype=float))


def test_spectral_reweight_factor_rejects_out_of_range_indices():
    """Only source indices in [2, 5] are accepted."""
    log10_energy = np.zeros(1)
    for bad_index in (1.9, 5.1):
        with pytest.raises(ValueError, match=r"within \[2, 5\]"):
            _spectral_reweight_factor(log10_energy, bad_index)