diff --git a/src/truckscenes/truckscenes.py b/src/truckscenes/truckscenes.py index 5f44e80..6d0c3d5 100644 --- a/src/truckscenes/truckscenes.py +++ b/src/truckscenes/truckscenes.py @@ -50,7 +50,8 @@ def __init__(self, self.verbose = verbose self.table_names = ['attribute', 'calibrated_sensor', 'category', 'ego_motion_cabin', 'ego_motion_chassis', 'ego_pose', 'instance', 'sample', - 'sample_annotation', 'sample_data', 'scene', 'sensor', 'visibility'] + 'sample_annotation', 'sample_data', 'scene', 'sensor', 'visibility', + 'weather_annotation'] assert osp.exists(self.table_root), \ f'Database version not found: {self.table_root}' @@ -73,6 +74,7 @@ def __init__(self, self.scene = self.__load_table__('scene') self.sensor = self.__load_table__('sensor') self.visibility = self.__load_table__('visibility') + self.weather_annotation = self.__load_table__('weather_annotation') # Initialize the colormap which maps from class names to RGB values. self.colormap = colormap.get_colormap() @@ -493,6 +495,14 @@ def list_scenes(self) -> None: def list_sample(self, sample_token: str) -> None: self.explorer.list_sample(sample_token) + def get_scenes_weather_annotations_filtered(self, conditions: List[Tuple[str, str, float]]) -> List[str]: + if 'weather_annotation' not in self.table_names: + return [] + return self.explorer.get_scenes_weather_annotations_filtered(conditions) + + def get_scenes_description_filtered(self, conditions: List[Tuple[str, str]]) -> List[str]: + return self.explorer.get_scenes_description_filtered(conditions) + def render_pointcloud_in_image(self, sample_token: str, dot_size: int = 5, pointsensor_channel: str = 'LIDAR_LEFT', camera_channel: str = 'CAMERA_LEFT_FRONT', diff --git a/src/truckscenes/utils/visualization_utils.py b/src/truckscenes/utils/visualization_utils.py index e2c59a0..7e4f35b 100644 --- a/src/truckscenes/utils/visualization_utils.py +++ b/src/truckscenes/utils/visualization_utils.py @@ -126,6 +126,118 @@ def list_sample(self, sample_token: str) -> None: print(f"sample_annotation_token: {ann_record['token']}" f", category: {ann_record['category_name']}") + def get_scenes_weather_annotations_filtered(self, conditions: List[Tuple[str, str, float]]) -> List[str]: + """ + Filters scenes based on annotated weather conditions and returns matching scene tokens. + + Arguments: + conditions: List of tuples (field, operator, value) where: + - field: Weather field name (any field from weather_annotation except 'token') + - operator: Comparison operator ('<', '>', '<=', '>=', '==', '!=') + - value: Numeric threshold value + + Returns: + List of scene tokens that match all conditions. + + Example: + # Find scenes with wind between 2 and 5, and rain > 1.0 + conditions = [ + ('wind', '>=', 2.0), + ('wind', '<=', 5.0), + ('rain', '>', 1.0) + ] + scene_tokens = ts.get_scenes_weather_annotations_filtered(conditions) + """ + # Check if weather_annotation exists and has data + if not hasattr(self.trucksc, 'weather_annotation') or not self.trucksc.weather_annotation: + return [] + + # Get valid fields from weather_annotation data (excluding 'token') + valid_fields = set(self.trucksc.weather_annotation[0].keys()) - {'token'} + valid_operators = {'<', '>', '<=', '>=', '==', '!='} + + # Validate all conditions + for field, operator, value in conditions: + assert field in valid_fields, f"Invalid field '{field}'. Valid fields: {valid_fields}" + assert operator in valid_operators, f"Invalid operator '{operator}'. Valid operators: {valid_operators}" + assert isinstance(value, (int, float)), f"Value must be numeric, got {type(value)}" + + def _evaluate_condition(field_value: float, operator: str, threshold: float) -> bool: + """Evaluate a single condition.""" + if operator == '<': + return field_value < threshold + elif operator == '>': + return field_value > threshold + elif operator == '<=': + return field_value <= threshold + elif operator == '>=': + return field_value >= threshold + elif operator == '==': + return abs(field_value - threshold) < 1e-6 # Float equality with tolerance + elif operator == '!=': + return abs(field_value - threshold) >= 1e-6 # Float inequality with tolerance + return False + + matching_scenes = [] + + for scene_record in self.trucksc.scene: + weather_token = scene_record['weather_annotation_token'] + weather_record = self.trucksc.get('weather_annotation', weather_token) + + # Check if scene matches all conditions + matches_all = True + for field, operator, value in conditions: + field_value = weather_record[field] + if not _evaluate_condition(field_value, operator, value): + matches_all = False + break + + if matches_all: + matching_scenes.append(scene_record['token']) + + return matching_scenes + + def get_scenes_description_filtered(self, conditions: List[Tuple[str, str]]) -> List[str]: + """ + Filters scenes based on description keywords and returns matching scene tokens. + + Arguments: + conditions: List of tuples (keyword, operator) where: + - keyword: String to search for in scene description + - operator: Comparison operator ('==' for contains, '!=' for not contains) + + Returns: + List of scene tokens that match all conditions. + """ + valid_operators = {'==', '!='} + + # Validate all conditions + for keyword, operator in conditions: + assert isinstance(keyword, str), f"Keyword must be string, got {type(keyword)}" + assert operator in valid_operators, f"Invalid operator '{operator}'. Valid operators: {valid_operators}" + + matching_scenes = [] + + for scene_record in self.trucksc.scene: + description = scene_record['description'] + + # Check if scene matches all conditions + matches_all = True + for keyword, operator in conditions: + if operator == '==': + if keyword not in description: + matches_all = False + break + elif operator == '!=': + if keyword in description: + matches_all = False + break + + if matches_all: + matching_scenes.append(scene_record['token']) + + return matching_scenes + def map_pointcloud_to_image(self, pointsensor_token: str, camera_token: str,