Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/eventdisplay_ml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ def configure_training(analysis_type):
parser.add_argument(
"--model_prefix",
required=True,
help=("Path to directory for writing XGBoost models (without n_tel / energy bin suffix)."),
help=("Path to directory for writing XGBoost models (energy bin suffix)."),
Comment thread
GernotMaier marked this conversation as resolved.
Outdated
)
parser.add_argument(
"--hyperparameter_config",
help="Path to JSON file with hyperparameter configuration.",
default=None,
type=str,
)
parser.add_argument("--n_tel", type=int, help="Telescope multiplicity (2, 3, or 4).")
parser.add_argument(
"--train_test_fraction",
type=float,
Expand Down Expand Up @@ -111,7 +110,6 @@ def configure_training(analysis_type):

_logger.info(f"--- XGBoost {analysis_type} training ---")
_logger.info(f"Observatory: {model_configs.get('observatory')}")
_logger.info(f"Telescope multiplicity: {model_configs.get('n_tel')}")
_logger.info(f"Model output prefix: {model_configs.get('model_prefix')}")
_logger.info(f"Train vs test fraction: {model_configs['train_test_fraction']}")
_logger.info(f"Random state: {model_configs['random_state']}")
Expand All @@ -131,15 +129,14 @@ def configure_training(analysis_type):

if analysis_type == "stereo_analysis":
model_configs["pre_cuts"] = pre_cuts_regression(
model_configs.get("n_tel"), min_images=model_configs.get("min_images", 2)
min_images=model_configs.get("min_images", 2)
)
elif analysis_type == "classification":
_logger.info(f"Energy bin {model_configs['energy_bin_number']}")
model_parameters = utils.load_model_parameters(
model_configs["model_parameters"], model_configs["energy_bin_number"]
)
model_configs["pre_cuts"] = pre_cuts_classification(
model_configs.get("n_tel"),
e_min=np.power(10.0, model_parameters.get("energy_bins_log10_tev", []).get("E_min")),
e_max=np.power(10.0, model_parameters.get("energy_bins_log10_tev", []).get("E_max")),
)
Expand All @@ -165,7 +162,7 @@ def configure_apply(analysis_type):
"--model_prefix",
required=True,
metavar="MODEL_PREFIX",
help=("Path to directory containing XGBoost models (without n_tel / energy bin suffix)."),
help=("Path to directory containing XGBoost models (without energy bin suffix)."),
)
parser.add_argument(
"--model_name",
Expand Down
6 changes: 2 additions & 4 deletions src/eventdisplay_ml/hyper_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,12 @@ def _load_hyper_parameters_from_file(config_file):
return hyperparameters


def pre_cuts_regression(n_tel, min_images=2):
def pre_cuts_regression(min_images=2):
"""
Get pre-cuts for regression analysis.

Parameters
----------
n_tel : int or None
Number of telescopes (not currently used).
min_images : int
Minimum number of images (DispNImages) for quality cut (default: 2).

Expand All @@ -112,7 +110,7 @@ def pre_cuts_regression(n_tel, min_images=2):
return event_cut if event_cut else None


def pre_cuts_classification(n_tel, e_min, e_max):
def pre_cuts_classification(e_min, e_max):
"""Get pre-cuts for classification analysis (no multiplicity filter)."""
event_cut = f"(Erec >= {e_min}) & (Erec < {e_max})"
event_cut += " & " + " & ".join(f"({c})" for c in PRE_CUTS_CLASSIFICATION)
Expand Down
Loading