@@ -460,7 +460,7 @@ def flatten_telescope_data_vectorized(
460460 flat_features [f"{ var } _{ tel_idx } " ] = data_normalized [:, tel_idx ]
461461
462462 index = _get_index (df , n_evt )
463- df_flat = flatten_telescope_variables (n_tel , flat_features , index , tel_config )
463+ df_flat = flatten_telescope_variables (n_tel , flat_features , index , tel_config , analysis_type )
464464 return pd .concat (
465465 [df_flat , extra_columns (df , analysis_type , training , index , tel_config , observatory )],
466466 axis = 1 ,
@@ -814,7 +814,7 @@ def apply_clip_intervals(df, n_tel=None, apply_log10=None):
814814 df .loc [mask_to_log , var_base ] = np .log10 (df .loc [mask_to_log , var_base ])
815815
816816
817- def flatten_telescope_variables (n_tel , flat_features , index , tel_config = None ):
817+ def flatten_telescope_variables (n_tel , flat_features , index , tel_config = None , analysis_type = None ):
818818 """Generate dataframe for telescope variables flattened for all telescopes.
819819
820820 Creates features for all telescope IDs, using NaN as default value for missing data.
@@ -829,13 +829,19 @@ def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
829829 DataFrame index.
830830 tel_config : dict, optional
831831 Telescope configuration with 'max_tel_id' key.
832+ analysis_type : str, optional
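833+ Type of analysis, e.g. "classification" or "stereo_analysis". The 'size'-related columns (size_{i}, size_dist2_{i}) are kept only for "stereo_analysis"; for any other value, including the default None, they are dropped.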
832834 """
833835 df_flat = pd .DataFrame (flat_features , index = index )
834836 df_flat = df_flat .astype (np .float32 )
835837
836838 # Determine max telescope ID from config or use n_tel
837839 max_tel_id = tel_config ["max_tel_id" ] if tel_config else (n_tel - 1 )
838840
841+ keep_size_vars = analysis_type == "stereo_analysis"
842+ if not keep_size_vars :
843+ _logger .info (f"Dropping 'size'-related variables for { analysis_type } analysis." )
844+
839845 new_cols = {}
840846 for i in range (max_tel_id + 1 ): # Iterate over all possible telescopes
841847 if f"Disp_T_{ i } " in df_flat :
@@ -844,7 +850,7 @@ def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
844850 if f"loss_{ i } " in df_flat and f"dist_{ i } " in df_flat :
845851 new_cols [f"loss_loss_{ i } " ] = df_flat [f"loss_{ i } " ] ** 2
846852 new_cols [f"loss_dist_{ i } " ] = df_flat [f"loss_{ i } " ] * df_flat [f"dist_{ i } " ]
847- if f"size_{ i } " in df_flat and f"dist_{ i } " in df_flat :
853+ if f"size_{ i } " in df_flat and f"dist_{ i } " in df_flat and keep_size_vars :
848854 new_cols [f"size_dist2_{ i } " ] = df_flat [f"size_{ i } " ] / (df_flat [f"dist_{ i } " ] ** 2 + 1e-6 )
849855 if f"width_{ i } " in df_flat and f"length_{ i } " in df_flat :
850856 new_cols [f"width_length_{ i } " ] = df_flat [f"width_{ i } " ] / (df_flat [f"length_{ i } " ] + 1e-6 )
@@ -873,6 +879,8 @@ def flatten_telescope_variables(n_tel, flat_features, index, tel_config=None):
873879 if f"cen_y_{ i } " in df_flat and f"fpointing_dy_{ i } " in df_flat :
874880 df_flat [f"cen_y_{ i } " ] = df_flat [f"cen_y_{ i } " ] + df_flat [f"fpointing_dy_{ i } " ]
875881 df_flat = df_flat .drop (columns = [f"fpointing_dx_{ i } " , f"fpointing_dy_{ i } " ], errors = "ignore" )
882+ if not keep_size_vars :
883+ df_flat = df_flat .drop (columns = [f"size_{ i } " ], errors = "ignore" )
876884
877885 return df_flat
878886
0 commit comments