Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ dependencies:
- numpy
- pandas
- pre-commit
- pytest
- pytest-cov
- pytest-mock
- scikit-learn
- scipy
- shellcheck
- tabulate
- towncrier
- uproot
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = python
149 changes: 137 additions & 12 deletions python/applyXGBoostforDirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,92 @@
import argparse
import logging
import os
import time

import joblib
import numpy as np
import pandas as pd
import uproot
from trainXGBoostforDirection import TRAINING_VARIABLES
from training_variables import xgb_training_variables

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger("applyXGBoostforDirection")


def parse_image_selection(image_selection_str):
    """
    Parse the image_selection parameter.

    Can be either:
    - Bit-coded value (e.g., 14 = 0b1110 = telescopes 1,2,3)
    - Comma-separated indices (e.g., "1,2,3")

    Parameters
    ----------
    image_selection_str : str or None
        Selection string; falsy values (None, "") disable selection.

    Returns
    -------
    list of int or None
        Selected telescope indices, or None when no selection is requested.

    Raises
    ------
    ValueError
        If the string is neither a valid comma-separated index list nor a
        valid non-negative bit-coded integer.
    """
    if not image_selection_str:
        return None

    # Parse as comma-separated indices, e.g. "1,2,3"
    if "," in image_selection_str:
        try:
            indices = [int(x.strip()) for x in image_selection_str.split(",")]
            _logger.info(f"Image selection indices: {indices}")
            return indices
        except ValueError:
            # Not a clean index list; fall through to the bit-coded parse below.
            pass

    # Parse as bit-coded value, e.g. 14 = 0b1110 -> telescopes 1, 2, 3
    try:
        bit_value = int(image_selection_str)
    except ValueError:
        # 'from None' suppresses the confusing "during handling of the above
        # exception" chained traceback from the failed int() conversion.
        raise ValueError(
            f"Invalid image_selection format: {image_selection_str}. "
            "Use bit-coded value (e.g., 14) or comma-separated indices (e.g., '1,2,3')"
        ) from None
    if bit_value < 0:
        # A negative value has all high bits set in two's complement, so the
        # shift below would silently select unintended telescopes; reject it.
        raise ValueError(
            f"Invalid image_selection format: {image_selection_str}. "
            "Use bit-coded value (e.g., 14) or comma-separated indices (e.g., '1,2,3')"
        )
    # Extract set bit positions (0-indexed); the array supports 4 telescopes.
    indices = [i for i in range(4) if (bit_value >> i) & 1]
    _logger.info(f"Image selection from bit-coded value {bit_value}: {indices}")
    return indices


def filter_by_telescope_selection(df, selected_indices):
    """
    Build a selection mask for events where all selected telescope indices
    are present in DispTelList_T, OR events with 4 telescopes.

    IMPORTANT: Does not drop events. Returns a boolean mask of length len(df).
    Use this mask to process only selected events while preserving original order.

    Parameters
    ----------
    - df: DataFrame with DispTelList_T column
    - selected_indices: List of selected telescope indices (or None)

    Returns
    -------
    pd.Series[bool]: selection mask (True = selected, False = unselected)
    """
    if selected_indices is None:
        # No selection requested: keep every event.
        return pd.Series([True] * len(df), index=df.index)

    _logger.info(f"Filtering events for telescope selection: {selected_indices}")

    required = set(selected_indices)

    def event_is_selected(tel_list):
        """True when every required telescope participates, or 4 images exist."""
        return required.issubset(tel_list) or len(tel_list) == 4

    mask = df["DispTelList_T"].apply(event_is_selected)

    n_selected = mask.sum()
    _logger.info(
        f"Selection: {n_selected} of {len(df)} events "
        f"({100 * n_selected / len(df):.1f}%)"
    )
    return mask


def load_all_events(input_file, max_events=None):
"""
Load all events from the input ROOT file without filtering by n_tel.
Expand All @@ -44,7 +119,7 @@ def load_all_events(input_file, max_events=None):
"Yoff_intersect",
"fpointing_dx",
"fpointing_dy",
] + [var for var in TRAINING_VARIABLES]
] + [var for var in xgb_training_variables()]

with uproot.open(input_file) as root_file:
if "data" not in root_file:
Expand Down Expand Up @@ -134,7 +209,7 @@ def flatten_data_vectorized(df, n_tel, training_variables):
return df_flat


def apply_models(df, model_dir):
def apply_models(df, model_dir, selection_mask=None):
"""
Apply trained XGBoost models to all events in the DataFrame.
Returns arrays of predicted Xoff and Yoff for each event.
Expand All @@ -161,8 +236,9 @@ def apply_models(df, model_dir):
pred_xoff = np.full(n_events, np.nan)
pred_yoff = np.full(n_events, np.nan)

# Group events by DispNImages for batch processing
grouped = df.groupby("DispNImages")
# Group selected events (if mask provided) by DispNImages for batch processing
df_to_group = df[selection_mask] if selection_mask is not None else df
grouped = df_to_group.groupby("DispNImages")

for n_tel, group_df in grouped:
n_tel = int(n_tel)
Expand All @@ -175,7 +251,7 @@ def apply_models(df, model_dir):
_logger.info(f"Processing {len(group_df)} events with n_tel={n_tel}")

# Add fpointing_dx and fpointing_dy to the training variables for flattening
training_vars_with_pointing = TRAINING_VARIABLES + [
training_vars_with_pointing = xgb_training_variables() + [
"fpointing_dx",
"fpointing_dy",
]
Expand All @@ -198,7 +274,19 @@ def apply_models(df, model_dir):
pred_xoff[idx] = predictions[i, 0]
pred_yoff[idx] = predictions[i, 1]

# Fill unselected events with -999 as requested
if selection_mask is not None:
unselected = ~selection_mask.values
pred_xoff[unselected] = -999
pred_yoff[unselected] = -999
_logger.info(
f"Filled {unselected.sum()} unselected events with -999 for predictions"
)

_logger.info("Predictions complete")
_logger.info(
f"Prediction arrays length: {len(pred_xoff)} (input events: {len(df)})"
)
return pred_xoff, pred_yoff


Expand All @@ -222,23 +310,60 @@ def write_output_root_file(output_file, pred_xoff, pred_yoff):

def main():
parser = argparse.ArgumentParser(
description=("Apply XGBoost Multi-Target BDTs for Direction Reconstruction")
description=("Apply XGBoost Multi-Target BDTs for Stereo Reconstruction")
)
parser.add_argument(
"--input-file",
required=True,
metavar="INPUT.root",
help="Path to input mscw ROOT file",
)
parser.add_argument(
"--model-dir",
required=True,
metavar="MODEL_DIR",
help="Directory containing XGBoost models",
)
parser.add_argument(
"--output-file",
required=True,
metavar="OUTPUT.root",
help="Output ROOT file path for predictions",
)
parser.add_argument(
"--image-selection",
type=str,
default=None,
help=(
"Optional telescope selection. Can be bit-coded (e.g., 14 for telescopes 1,2,3) "
"or comma-separated indices (e.g., '1,2,3'). "
"Keeps events with all selected telescopes or 4-telescope events."
),
)
parser.add_argument("input_file", help="Input mscw ROOT file.")
parser.add_argument("model_dir", help="Directory with XGBoost models.")
parser.add_argument("output_file", help="Output ROOT file with applied BDTs.")
args = parser.parse_args()

start_time = time.time()
_logger.info("--- XGBoost Multi-Target Direction Evaluation ---")
_logger.info(f"Input file: {args.input_file}")
_logger.info(f"Model directory: {args.model_dir}")
_logger.info(f"Output file: {args.output_file}")
if args.image_selection:
_logger.info(f"Image selection: {args.image_selection}")

df = load_all_events(args.input_file, max_events=None)
pred_xoff, pred_yoff = apply_models(df, args.model_dir)

selection_mask = None
if args.image_selection:
selected_indices = parse_image_selection(args.image_selection)
selection_mask = filter_by_telescope_selection(df, selected_indices)
else:
_logger.info("No image selection applied")

pred_xoff, pred_yoff = apply_models(df, args.model_dir, selection_mask)
write_output_root_file(args.output_file, pred_xoff, pred_yoff)

_logger.info("Processing complete.")
elapsed_time = time.time() - start_time
_logger.info(f"Processing complete. Total time: {elapsed_time:.2f} seconds")


if __name__ == "__main__":
Expand Down
100 changes: 100 additions & 0 deletions python/tests/test_applyXGBoostforDirection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import applyXGBoostforDirection as mod
import numpy as np
import pandas as pd


def test_parse_image_selection_indices():
    """Comma-separated selection strings parse to explicit index lists."""
    assert mod.parse_image_selection("1,2,3") == [1, 2, 3]


def test_parse_image_selection_bits():
    """Bit-coded value 14 (0b1110) selects telescopes 1, 2 and 3."""
    assert mod.parse_image_selection("14") == [1, 2, 3]


def test_filter_by_telescope_selection_returns_mask_and_preserves_length():
    """Mask keeps full length/order: rows holding all selected telescopes
    (or four images) are True, all other rows False."""
    tel_lists = [
        [1, 2, 3],  # has 1,2,3
        [1, 3],  # missing 2
        [0, 1, 2, 3],  # 4 telescopes -> always included
        [0, 2],  # missing 1,3
    ]
    df = pd.DataFrame({"DispTelList_T": tel_lists})

    mask = mod.filter_by_telescope_selection(df, [1, 2, 3])

    assert isinstance(mask, pd.Series)
    assert len(mask) == len(df)
    # Rows 0 and 2 qualify; rows 1 and 3 do not.
    assert list(mask) == [True, False, True, False]


class DummyModel:
    """Stand-in regressor that predicts one fixed (x, y) pair for every row."""

    def __init__(self, out_val=(0.0, 0.0)):
        # Constant prediction stored as a float array of shape (2,).
        self.out_val = np.array(out_val, dtype=float)

    def predict(self, X):
        """Return an (n_rows, 2) array in which each row equals out_val."""
        n_rows = len(X)
        return np.repeat(self.out_val[np.newaxis, :], n_rows, axis=0)


def test_apply_models_with_selection_mask(monkeypatch):
    """End-to-end check of apply_models with a telescope-selection mask.

    Verifies that the prediction arrays keep the input length, that selected
    rows receive the per-n_tel model outputs, and that unselected rows are
    filled with -999. Model loading, file-existence checks, and data
    flattening are monkeypatched so no real models or ROOT files are needed.
    """
    # Build a minimal DataFrame with required columns
    df = pd.DataFrame(
        {
            "DispNImages": [3, 3, 4, 2, 3],
            "DispTelList_T": [
                [1, 2, 3],  # selected
                [1, 3],  # unselected
                [0, 1, 2, 3],  # 4 telescopes -> selected
                [0, 1],  # unselected
                [1, 2, 3],  # selected
            ],
            # Truth-related columns used downstream; values not relevant for this test
            "Xoff": [0, 0, 0, 0, 0],
            "Yoff": [0, 0, 0, 0, 0],
            "Xoff_intersect": [0, 0, 0, 0, 0],
            "Yoff_intersect": [0, 0, 0, 0, 0],
        }
    )

    # Selection: require telescopes 1,2,3 or len==4
    selection_mask = mod.filter_by_telescope_selection(df, [1, 2, 3])

    # Monkeypatch model loading and existence checks
    monkeypatch.setattr(mod.os.path, "exists", lambda p: True)

    # Always return a DummyModel; value differentiates by n_tel if desired
    def fake_load(path):
        # Outputs are chosen per model file name so the assertions below can
        # tell which n_tel model produced each row's prediction.
        if "ntel4" in path:
            return DummyModel(out_val=(4.0, -4.0))
        elif "ntel3" in path:
            return DummyModel(out_val=(3.0, -3.0))
        else:
            return DummyModel(out_val=(2.0, -2.0))

    monkeypatch.setattr(mod.joblib, "load", fake_load)

    # Monkeypatch flattening to avoid complex array inputs
    def fake_flatten(group_df, n_tel, training_vars):
        # Provide a simple feature column not in excluded list
        return pd.DataFrame({"feature": np.zeros(len(group_df))}, index=group_df.index)

    monkeypatch.setattr(mod, "flatten_data_vectorized", fake_flatten)

    pred_x, pred_y = mod.apply_models(df, "dummy_models_dir", selection_mask)

    # Output length must match input length
    assert len(pred_x) == len(df)
    assert len(pred_y) == len(df)

    # Unselected rows must be -999; selected rows come from model values
    # From selection_mask above: rows 0,2,4 are selected; 1,3 are unselected
    expected_x = [3.0, -999.0, 4.0, -999.0, 3.0]
    expected_y = [-3.0, -999.0, -4.0, -999.0, -3.0]
    assert np.allclose(pred_x, expected_x, equal_nan=False)
    assert np.allclose(pred_y, expected_y, equal_nan=False)
27 changes: 3 additions & 24 deletions python/trainXGBoostforDirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,11 @@
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputRegressor
from training_variables import xgb_training_variables

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger("trainXGBoostforDirection")

# Telescope-type training variables
# Disp variables with different indexing logic in data preparation
TRAINING_VARIABLES = [
"Disp_T",
"DispXoff_T",
"DispYoff_T",
"DispWoff_T",
"cen_x",
"cen_y",
"cosphi",
"sinphi",
"loss",
"size",
"dist",
"width",
"length",
"asym",
"tgrad_x",
"R_core",
]
N_TEL_VAR = len(TRAINING_VARIABLES)


def load_and_flatten_data(input_files, n_tel, max_events, training_step=True):
"""
Expand All @@ -69,7 +48,7 @@ def load_and_flatten_data(input_files, n_tel, max_events, training_step=True):
"MCxoff",
"MCyoff",
"MCe0",
] + [var for var in TRAINING_VARIABLES]
] + [var for var in xgb_training_variables()]

dfs = []
if max_events > 0:
Expand Down Expand Up @@ -114,7 +93,7 @@ def load_and_flatten_data(input_files, n_tel, max_events, training_step=True):
) # * data_tree["MCe0"]
)

df_flat = flatten_data_vectorized(data_tree, n_tel, TRAINING_VARIABLES)
df_flat = flatten_data_vectorized(data_tree, n_tel, xgb_training_variables())

if training_step:
df_flat["MCxoff"] = data_tree["MCxoff"]
Expand Down
Loading
Loading