Skip to content

Commit ece4ae1

Browse files
authored
Merge pull request #45 from Eventdisplay/classifier-with-sorted-size
Classifier / VTS.
2 parents 79522c7 + 731c50b commit ece4ae1

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

docs/changes/45.maintenance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update the gamma/hadron (g/h) separation to the new sorting scheme for telescope-dependent variables.

src/eventdisplay_ml/data_processing.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def flatten_telescope_data_vectorized(
460460
flat_features[f"{var}_{tel_idx}"] = data_normalized[:, tel_idx]
461461

462462
index = _get_index(df, n_evt)
463-
df_flat = flatten_telescope_variables(n_tel, flat_features, index, tel_config)
463+
df_flat = flatten_telescope_variables(n_tel, flat_features, index, tel_config, analysis_type)
464464
return pd.concat(
465465
[df_flat, extra_columns(df, analysis_type, training, index, tel_config, observatory)],
466466
axis=1,
@@ -814,7 +814,7 @@ def apply_clip_intervals(df, n_tel=None, apply_log10=None):
814814
df.loc[mask_to_log, var_base] = np.log10(df.loc[mask_to_log, var_base])
815815

816816

817-
def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
817+
def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None, analysis_type=None):
818818
"""Generate dataframe for telescope variables flattened for all telescopes.
819819
820820
Creates features for all telescope IDs, using NaN as default value for missing data.
@@ -829,13 +829,19 @@ def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
829829
DataFrame index.
830830
tel_config : dict, optional
831831
Telescope configuration with 'max_tel_id' key.
832+
analysis_type : str, optional
833+
Type of analysis, e.g. "classification" or "stereo_analysis".
832834
"""
833835
df_flat = pd.DataFrame(flat_features, index=index)
834836
df_flat = df_flat.astype(np.float32)
835837

836838
# Determine max telescope ID from config or use n_tel
837839
max_tel_id = tel_config["max_tel_id"] if tel_config else (n_tel - 1)
838840

841+
keep_size_vars = analysis_type == "stereo_analysis"
842+
if not keep_size_vars:
843+
_logger.info(f"Dropping 'size'-related variables for {analysis_type} analysis.")
844+
839845
new_cols = {}
840846
for i in range(max_tel_id + 1): # Iterate over all possible telescopes
841847
if f"Disp_T_{i}" in df_flat:
@@ -844,7 +850,7 @@ def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
844850
if f"loss_{i}" in df_flat and f"dist_{i}" in df_flat:
845851
new_cols[f"loss_loss_{i}"] = df_flat[f"loss_{i}"] ** 2
846852
new_cols[f"loss_dist_{i}"] = df_flat[f"loss_{i}"] * df_flat[f"dist_{i}"]
847-
if f"size_{i}" in df_flat and f"dist_{i}" in df_flat:
853+
if f"size_{i}" in df_flat and f"dist_{i}" in df_flat and keep_size_vars:
848854
new_cols[f"size_dist2_{i}"] = df_flat[f"size_{i}"] / (df_flat[f"dist_{i}"] ** 2 + 1e-6)
849855
if f"width_{i}" in df_flat and f"length_{i}" in df_flat:
850856
new_cols[f"width_length_{i}"] = df_flat[f"width_{i}"] / (df_flat[f"length_{i}"] + 1e-6)
@@ -873,6 +879,8 @@ def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
873879
if f"cen_y_{i}" in df_flat and f"fpointing_dy_{i}" in df_flat:
874880
df_flat[f"cen_y_{i}"] = df_flat[f"cen_y_{i}"] + df_flat[f"fpointing_dy_{i}"]
875881
df_flat = df_flat.drop(columns=[f"fpointing_dx_{i}", f"fpointing_dy_{i}"], errors="ignore")
882+
if not keep_size_vars:
883+
df_flat = df_flat.drop(columns=[f"size_{i}"], errors="ignore")
876884

877885
return df_flat
878886

src/eventdisplay_ml/features.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def telescope_features(analysis_type):
7474
List of telescope-level feature names.
7575
"""
7676
var = [
77+
"size",
7778
"cosphi",
7879
"sinphi",
7980
"loss",
@@ -95,7 +96,6 @@ def telescope_features(analysis_type):
9596

9697
return [
9798
*var,
98-
"size",
9999
"cen_x",
100100
"cen_y",
101101
"E",
@@ -147,9 +147,12 @@ def _classification_features():
147147
"MSCL",
148148
"ArrayPointing_Elevation",
149149
"ArrayPointing_Azimuth",
150+
"Xcore",
151+
"Ycore",
150152
]
151153
# energy used to bin the models, but not as feature
152-
return var_tel + var_array + ["Erec"]
154+
# size used for sorting events during flattening, but not as feature
155+
return var_tel + var_array + ["Erec", "size"]
153156

154157

155158
def clip_intervals():

src/eventdisplay_ml/geomag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"CTAO-SOUTH": {
2020
"BX": 20.552e-6, # Tesla
2121
"BY": 0.0, # Tesla
22-
"BZ": -9.367 - 6, # Tesla
22+
"BZ": -9.367e-6, # Tesla
2323
},
2424
}
2525

src/eventdisplay_ml/models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,7 @@ def process_file_chunked(analysis_type, model_configs):
360360
threshold_keys = sorted(
361361
{
362362
eff
363-
for n_tel_models in model_configs["models"].values()
364-
for e_bin_models in n_tel_models.values()
363+
for e_bin_models in model_configs["models"].values()
365364
for eff in (e_bin_models.get("thresholds") or {}).keys()
366365
}
367366
)

src/eventdisplay_ml/scripts/train_xgb_classify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Uses image and stereo parameters to train classification BDTs to separate
55
gamma-ray events from hadronic background events.
66
7-
Separate BDTs are trained for 2, 3, and 4 telescope multiplicity events.
7+
Trains a single classifier on all telescope multiplicity events.
88
"""
99

1010
import logging

0 commit comments

Comments
 (0)