2121import json
2222import os
2323import pathlib
24- from typing import Any , Iterator , List , Mapping , Optional , Tuple , Union
24+ from typing import Any , Dict , Iterator , List , Mapping , Optional , Tuple , Union
2525
2626import numpy as np
2727
4848IMAGE_MASK_FILENAME_EXTENSION = '.png'
4949DEPTH_FILENAME_EXTENSION = '.npz'
5050CAMERA_POSES_CSV_FILENAME = 'camera_poses.csv'
51+ FRAME_SEGMENTATION_ANNOTATION_TYPE_FILENAME = (
52+ 'frame_segmentation_annotation_type.json'
53+ )
5154
5255
5356FEATURE_SESSION_TYPE = 'session_type'
5659FEATURE_IMAGE_RIGHT = 'image_right'
5760FEATURE_METRIC_DEPTH_LABEL = 'metric_depth'
5861FEATURE_HAS_METRIC_DEPTH_LABEL = 'has_metric_depth'
62+ FEATURE_SEGMENTATION_ANNOTATION_TYPE = 'segmentation_annotation_type'
5963FEATURE_PANOPTIC_MASK_LABEL = 'panoptic_label'
6064FEATURE_HAS_PANOPTIC_MASK_LABEL = 'has_panoptic_label'
6165FEATURE_SEMANTIC_LABEL = 'semantic_label'
@@ -102,7 +106,6 @@ def wrapped_path(path: Any) -> Any:
102106
103107
104108class DatasetViewMode (enum .Enum ):
105-
106109 """Configures which data to include in each sample.
107110
108111 STEREO_VIEW_FRAME_MODE: Stereo view in frame mode. Each sample contains both
@@ -433,10 +436,42 @@ def has_segmentation_annotation(self) -> bool:
433436 return True
434437 return False
435438
436- def lens_names (self , sensor_name : str ) -> List [str ]:
439+ def segmentation_annotation_type (
440+ self , input_sensor_name : str
441+ ) -> Optional [Dict [int , str ]]:
442+ """Returns segmentation annotation type for the given sensor."""
443+ # Computed once.
444+ if not hasattr (self , '_sensor_frame_segmentation_annotation_type' ):
445+ self ._sensor_frame_segmentation_annotation_type = {}
446+ for sensor_name in self .sensor_names :
447+ frame_segmentation_annotation_filepath = (
448+ self .base_path
449+ / sensor_name
450+ / LEFT_LENS_NAME
451+ / FRAME_SEGMENTATION_ANNOTATION_TYPE_FILENAME
452+ )
453+ if frame_segmentation_annotation_filepath .exists ():
454+ with wrapped_open (
455+ frame_segmentation_annotation_filepath , 'r'
456+ ) as fileptr :
457+ frame_segmentation_annotation_type = json .load (fileptr )
458+ self ._sensor_frame_segmentation_annotation_type [sensor_name ] = {}
459+ for (
460+ frame_num_str ,
461+ annotation_type ,
462+ ) in frame_segmentation_annotation_type .items ():
463+ self ._sensor_frame_segmentation_annotation_type [sensor_name ][
464+ int (frame_num_str )
465+ ] = annotation_type
466+ else :
467+ self ._sensor_frame_segmentation_annotation_type [sensor_name ] = None
468+
469+ return self ._sensor_frame_segmentation_annotation_type [input_sensor_name ]
470+
471+ def lens_names (self , input_sensor_name : str ) -> List [str ]:
437472 """Returns lens names in the session's sensor."""
438473
439- # Just compute once.
474+ # Computed once.
440475 if not hasattr (self , '_sensor_lens_names' ):
441476 self ._sensor_lens_names = {}
442477 for sensor_name in self .sensor_names :
@@ -448,10 +483,10 @@ def lens_names(self, sensor_name: str) -> List[str]:
448483 lens_names .sort ()
449484 self ._sensor_lens_names [sensor_name ] = lens_names
450485
451- return self ._sensor_lens_names [sensor_name ]
486+ return self ._sensor_lens_names [input_sensor_name ]
452487
453488 def camera_poses (
454- self , sensor_name : str
489+ self , input_sensor_name : str
455490 ) -> List [Mapping [str , Union [bool , np .ndarray ]]]:
456491 """Returns camera poses corresponding the session's sensor."""
457492
@@ -465,7 +500,7 @@ def camera_poses(
465500 csv_file
466501 )
467502
468- return self ._sensor_camera_poses [sensor_name ]
503+ return self ._sensor_camera_poses [input_sensor_name ]
469504
470505 def camera_intrinsics (
471506 self , sensor_name : str , lens_name : str
@@ -657,6 +692,9 @@ def _itersamples(
657692 else ''
658693 )
659694 sample [FEATURE_HAS_PANOPTIC_MASK_LABEL ] = segmentation_mask_exists
695+ sample [FEATURE_SEGMENTATION_ANNOTATION_TYPE ] = (
696+ ex .segmentation_annotation_type
697+ )
660698
661699 if self .config .feature_metric_depth_zed .to_include ():
662700 metric_zed_depth_filename = ex .metric_depth_zed_filename (
@@ -714,6 +752,16 @@ def has_metric_depth_zed(self):
714752 def has_segmentation_mask (self ):
715753 return self .segmentation_mask_filename (LEFT_LENS_NAME ).exists ()
716754
755+ @functools .cached_property
756+ def segmentation_annotation_type (self ) -> str :
757+ segmentation_annotation_type = self .session .segmentation_annotation_type (
758+ self .sensor_name
759+ )
760+ if segmentation_annotation_type is None :
761+ return 'NA'
762+ else :
763+ return segmentation_annotation_type [self .frame_num ]
764+
717765 @property
718766 def has_camera_pose (self ):
719767 # all examples have some camera pose
0 commit comments