@@ -405,6 +405,27 @@ def __init__(
405405 f"/dl1/event/telescope/parameters/tel_{ self .tel_ids [0 ]:03d} " ,
406406 ).colnames
407407
408+ # Columns to keep in the the example identifiers
409+ # This are the basic columns one need to do a
410+ # conventional IACT analysis with CNNs
411+ self .example_ids_keep_columns = ["table_index" , "obs_id" , "event_id" , "tel_id" ]
412+ if self .process_type == ProcessType .Simulation :
413+ self .example_ids_keep_columns .extend (
414+ [
415+ "true_energy" ,
416+ "true_shower_primary_id" ,
417+ "true_az" ,
418+ "telescope_pointing_azimuth" ,
419+ "true_alt" ,
420+ "telescope_pointing_altitude" ,
421+ "cam_coord_offset_x" ,
422+ "cam_coord_offset_y" ,
423+ "cam_coord_distance" ,
424+ ]
425+ )
426+ elif self .process_type == ProcessType .Observation :
427+ self .example_ids_keep_columns .extend (["time" , "event_type" ])
428+
408429 # Construct the example identifiers
409430 if self .mode == "mono" :
410431 self ._construct_mono_example_identifiers ()
@@ -442,15 +463,6 @@ def _construct_mono_example_identifiers(self):
442463 and constructs identifiers based on the event and telescope IDs. These
443464 identifiers are used to uniquely reference each example in the dataset.
444465 """
445- # Columns to keep in the the example identifiers
446- # This are the basic columns one need to do a
447- # conventional IACT analysis with CNNs
448- self .example_ids_keep_columns = ["table_index" , "obs_id" , "event_id" , "tel_id" ]
449- if self .process_type == ProcessType .Simulation :
450- self .example_ids_keep_columns .extend (
451- ["true_energy" , "true_alt" , "true_az" , "true_shower_primary_id" ]
452- )
453-
454466 simulation_info = []
455467 example_identifiers = []
456468 for file_idx , (filename , f ) in enumerate (self .files .items ()):
@@ -477,6 +489,14 @@ def _construct_mono_example_identifiers(self):
477489 right = simshower_table ,
478490 keys = ["obs_id" , "event_id" ],
479491 )
492+ # Add the spherical offsets w.r.t. to the telescope pointing
493+ tel_pointing = self .get_tel_pointing (f , [tel_id ])
494+ tel_table = join (
495+ left = tel_table ,
496+ right = tel_pointing ,
497+ keys = ["obs_id" , "tel_id" ],
498+ )
499+ tel_table = self ._transform_to_cam_coord_offsets (tel_table )
480500 tel_tables .append (tel_table )
481501 events = vstack (tel_tables )
482502
@@ -495,13 +515,6 @@ def _construct_mono_example_identifiers(self):
495515 events .keep_columns (self .example_ids_keep_columns )
496516 if self .process_type == ProcessType .Simulation :
497517 # Add the spherical offsets w.r.t. to the telescope pointing
498- tel_pointing = self .get_tel_pointing (f , self .tel_ids )
499- events = join (
500- left = events ,
501- right = tel_pointing ,
502- keys = ["obs_id" , "tel_id" ],
503- )
504- events = self ._transform_to_cam_coord_offsets (events )
505518 array_pointing = self .get_array_pointing (f )
506519 # Join the prediction table with the telescope pointing table
507520 events = join (
@@ -560,20 +573,7 @@ def _construct_stereo_example_identifiers(self):
560573 # Columns to keep in the the example identifiers
561574 # This are the basic columns one need to do a
562575 # conventional IACT analysis with CNNs
563- self .example_ids_keep_columns = [
564- "table_index" ,
565- "obs_id" ,
566- "event_id" ,
567- "tel_id" ,
568- "hillas_intensity" ,
569- ]
570- if self .process_type == ProcessType .Simulation :
571- self .example_ids_keep_columns .extend (
572- ["true_energy" , "true_alt" , "true_az" , "true_shower_primary_id" ]
573- )
574- elif self .process_type == ProcessType .Observation :
575- self .example_ids_keep_columns .extend (["time" , "event_type" ])
576-
576+ self .example_ids_keep_columns .extend (["hillas_intensity" ])
577577 simulation_info = []
578578 example_identifiers = []
579579 for file_idx , (filename , f ) in enumerate (self .files .items ()):
@@ -617,21 +617,19 @@ def _construct_stereo_example_identifiers(self):
617617 right = trigger_table ,
618618 keys = ["obs_id" , "event_id" ],
619619 )
620+ if self .process_type == ProcessType .Simulation :
621+ tel_pointing = self .get_tel_pointing (f , [tel_id ])
622+ merged_table = join (
623+ left = merged_table ,
624+ right = tel_pointing ,
625+ keys = ["obs_id" , "tel_id" ],
626+ )
627+ merged_table = self ._transform_to_cam_coord_offsets (
628+ merged_table
629+ )
620630 table_per_type .append (merged_table )
621631 table_per_type = vstack (table_per_type )
622-
623- table_per_type = table_per_type .group_by (["obs_id" , "event_id" ])
624632 table_per_type .keep_columns (self .example_ids_keep_columns )
625- if self .process_type == ProcessType .Simulation :
626- tel_pointing = self .get_tel_pointing (f , self .tel_ids )
627- table_per_type = join (
628- left = table_per_type ,
629- right = tel_pointing ,
630- keys = ["obs_id" , "tel_id" ],
631- )
632- table_per_type = self ._transform_to_cam_coord_offsets (
633- table_per_type
634- )
635633 # Apply the multiplicity cut based on the telescope type
636634 table_per_type = table_per_type .group_by (["obs_id" , "event_id" ])
637635
@@ -793,48 +791,38 @@ def _transform_to_cam_coord_offsets(self, table) -> Table:
793791 table : astropy.table.Table
794792 A Table with the spherical offsets and the angular separation added as new columns.
795793 """
796-
797- tel_tables = []
798- for tel_id in self .tel_ids :
799- tel_table = table .copy ()
800- tel_table = tel_table [tel_table ["tel_id" ] == tel_id ]
801- # Set the telescope pointing of the SkyOffsetSeparation tranform to the fix pointing
802- tel_ground_frame = self .subarray .tel_coords [
803- self .subarray .tel_ids_to_indices (tel_id )
804- ]
805- fix_tel_pointing = SkyCoord (
806- tel_table ["telescope_pointing_azimuth" ],
807- tel_table ["telescope_pointing_altitude" ],
808- location = tel_ground_frame .to_earth_location (),
809- obstime = LST_EPOCH ,
810- )
811- camera_frame = CameraFrame (
812- focal_length = self .subarray .tel [tel_id ].optics .equivalent_focal_length ,
813- rotation = self .subarray .tel [tel_id ].camera .geometry .pix_rotation ,
814- telescope_pointing = fix_tel_pointing ,
815- )
816- # Transform the true Alt/Az coordinates to camera coordinates
817- true_direction = SkyCoord (
818- tel_table ["true_az" ],
819- tel_table ["true_alt" ],
820- location = tel_ground_frame .to_earth_location (),
821- obstime = LST_EPOCH ,
822- )
823- true_cam_position = true_direction .transform_to (camera_frame )
824- true_cam_distance = np .sqrt (
825- true_cam_position .x ** 2 + true_cam_position .y ** 2
826- )
827- tel_table .keep_columns (["obs_id" , "tel_id" ])
828- tel_table .add_column (true_cam_position .x , name = "cam_coord_offset_x" )
829- tel_table .add_column (true_cam_position .y , name = "cam_coord_offset_y" )
830- tel_table .add_column (true_cam_distance , name = "cam_coord_distance" )
831- tel_tables .append (tel_table )
832- tel_tables = vstack (tel_tables )
833- table = join (
834- left = table ,
835- right = tel_tables ,
836- keys = ["obs_id" , "tel_id" ],
794+ # Get the telescope ID from the table
795+ tel_id = table ["tel_id" ][0 ]
796+ # Set the telescope pointing of the SkyOffsetSeparation tranform to the fix pointing
797+ tel_ground_frame = self .subarray .tel_coords [
798+ self .subarray .tel_ids_to_indices (tel_id )
799+ ]
800+ fix_tel_pointing = SkyCoord (
801+ table ["telescope_pointing_azimuth" ],
802+ table ["telescope_pointing_altitude" ],
803+ location = tel_ground_frame .to_earth_location (),
804+ obstime = LST_EPOCH ,
805+ )
806+ # Set the camera frame with the focal length and rotation of the camera
807+ camera_frame = CameraFrame (
808+ focal_length = self .subarray .tel [tel_id ].optics .equivalent_focal_length ,
809+ rotation = self .subarray .tel [tel_id ].camera .geometry .pix_rotation ,
810+ telescope_pointing = fix_tel_pointing ,
837811 )
812+ # Transform the true Alt/Az coordinates to camera coordinates
813+ true_direction = SkyCoord (
814+ table ["true_az" ],
815+ table ["true_alt" ],
816+ location = tel_ground_frame .to_earth_location (),
817+ obstime = LST_EPOCH ,
818+ )
819+ # Calculate the camera coordinate offsets and distance
820+ true_cam_position = true_direction .transform_to (camera_frame )
821+ true_cam_distance = np .sqrt (true_cam_position .x ** 2 + true_cam_position .y ** 2 )
822+ # Add the camera coordinate offsets and distance to the table
823+ table .add_column (true_cam_position .x , name = "cam_coord_offset_x" )
824+ table .add_column (true_cam_position .y , name = "cam_coord_offset_y" )
825+ table .add_column (true_cam_distance , name = "cam_coord_distance" )
838826 return table
839827
840828 def _transform_to_sky_spher_offsets (self , table ) -> Table :
@@ -1016,13 +1004,13 @@ def generate_stereo_batch(self, batch_indices) -> Table:
10161004 if "features" in group_element .colnames :
10171005 blank_input_row ["features" ] = blank_input
10181006 if "mono_feature_vectors" in group_element .colnames :
1019- blank_input_row [
1020- "mono_feature_vectors"
1021- ] = blank_mono_feature_vectors
1007+ blank_input_row ["mono_feature_vectors" ] = (
1008+ blank_mono_feature_vectors
1009+ )
10221010 if "stereo_feature_vectors" in group_element .colnames :
1023- blank_input_row [
1024- "stereo_feature_vectors"
1025- ] = blank_stereo_feature_vectors
1011+ blank_input_row ["stereo_feature_vectors" ] = (
1012+ blank_stereo_feature_vectors
1013+ )
10261014 batch .add_row (blank_input_row )
10271015 # Sort the batch with the new rows of blank inputs
10281016 batch .sort (["obs_id" , "event_id" , "tel_type_id" , "tel_id" ])
0 commit comments