Skip to content

Commit 4f36b5a

Browse files
Sagar Waghmare (sagarwaghmare69)
authored and committed
Support frame annotation type feature.
PiperOrigin-RevId: 604811120
1 parent c84ac8c commit 4f36b5a

File tree

5 files changed

+127
-31
lines changed

5 files changed

+127
-31
lines changed

sanpo_dataset/lib/common.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import json
2222
import os
2323
import pathlib
24-
from typing import Any, Iterator, List, Mapping, Optional, Tuple, Union
24+
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
2525

2626
import numpy as np
2727

@@ -48,6 +48,9 @@
4848
IMAGE_MASK_FILENAME_EXTENSION = '.png'
4949
DEPTH_FILENAME_EXTENSION = '.npz'
5050
CAMERA_POSES_CSV_FILENAME = 'camera_poses.csv'
51+
FRAME_SEGMENTATION_ANNOTATION_TYPE_FILENAME = (
52+
'frame_segmentation_annotation_type.json'
53+
)
5154

5255

5356
FEATURE_SESSION_TYPE = 'session_type'
@@ -56,6 +59,7 @@
5659
FEATURE_IMAGE_RIGHT = 'image_right'
5760
FEATURE_METRIC_DEPTH_LABEL = 'metric_depth'
5861
FEATURE_HAS_METRIC_DEPTH_LABEL = 'has_metric_depth'
62+
FEATURE_SEGMENTATION_ANNOTATION_TYPE = 'segmentation_annotation_type'
5963
FEATURE_PANOPTIC_MASK_LABEL = 'panoptic_label'
6064
FEATURE_HAS_PANOPTIC_MASK_LABEL = 'has_panoptic_label'
6165
FEATURE_SEMANTIC_LABEL = 'semantic_label'
@@ -102,7 +106,6 @@ def wrapped_path(path: Any) -> Any:
102106

103107

104108
class DatasetViewMode(enum.Enum):
105-
106109
"""Configures which data to include in each sample.
107110
108111
STEREO_VIEW_FRAME_MODE: Stereo view in frame mode. Each sample contains both
@@ -433,10 +436,42 @@ def has_segmentation_annotation(self) -> bool:
433436
return True
434437
return False
435438

436-
def lens_names(self, sensor_name: str) -> List[str]:
439+
def segmentation_annotation_type(
440+
self, input_sensor_name: str
441+
) -> Optional[Dict[int, str]]:
442+
"""Returns segmentation annotation type for the given sensor."""
443+
# Computed once.
444+
if not hasattr(self, '_sensor_frame_segmentation_annotation_type'):
445+
self._sensor_frame_segmentation_annotation_type = {}
446+
for sensor_name in self.sensor_names:
447+
frame_segmentation_annotation_filepath = (
448+
self.base_path
449+
/ sensor_name
450+
/ LEFT_LENS_NAME
451+
/ FRAME_SEGMENTATION_ANNOTATION_TYPE_FILENAME
452+
)
453+
if frame_segmentation_annotation_filepath.exists():
454+
with wrapped_open(
455+
frame_segmentation_annotation_filepath, 'r'
456+
) as fileptr:
457+
frame_segmentation_annotation_type = json.load(fileptr)
458+
self._sensor_frame_segmentation_annotation_type[sensor_name] = {}
459+
for (
460+
frame_num_str,
461+
annotation_type,
462+
) in frame_segmentation_annotation_type.items():
463+
self._sensor_frame_segmentation_annotation_type[sensor_name][
464+
int(frame_num_str)
465+
] = annotation_type
466+
else:
467+
self._sensor_frame_segmentation_annotation_type[sensor_name] = None
468+
469+
return self._sensor_frame_segmentation_annotation_type[input_sensor_name]
470+
471+
def lens_names(self, input_sensor_name: str) -> List[str]:
437472
"""Returns lens names in the session's sensor."""
438473

439-
# Just compute once.
474+
# Computed once.
440475
if not hasattr(self, '_sensor_lens_names'):
441476
self._sensor_lens_names = {}
442477
for sensor_name in self.sensor_names:
@@ -448,10 +483,10 @@ def lens_names(self, sensor_name: str) -> List[str]:
448483
lens_names.sort()
449484
self._sensor_lens_names[sensor_name] = lens_names
450485

451-
return self._sensor_lens_names[sensor_name]
486+
return self._sensor_lens_names[input_sensor_name]
452487

453488
def camera_poses(
454-
self, sensor_name: str
489+
self, input_sensor_name: str
455490
) -> List[Mapping[str, Union[bool, np.ndarray]]]:
456491
"""Returns camera poses corresponding the session's sensor."""
457492

@@ -465,7 +500,7 @@ def camera_poses(
465500
csv_file
466501
)
467502

468-
return self._sensor_camera_poses[sensor_name]
503+
return self._sensor_camera_poses[input_sensor_name]
469504

470505
def camera_intrinsics(
471506
self, sensor_name: str, lens_name: str
@@ -657,6 +692,9 @@ def _itersamples(
657692
else ''
658693
)
659694
sample[FEATURE_HAS_PANOPTIC_MASK_LABEL] = segmentation_mask_exists
695+
sample[FEATURE_SEGMENTATION_ANNOTATION_TYPE] = (
696+
ex.segmentation_annotation_type
697+
)
660698

661699
if self.config.feature_metric_depth_zed.to_include():
662700
metric_zed_depth_filename = ex.metric_depth_zed_filename(
@@ -714,6 +752,16 @@ def has_metric_depth_zed(self):
714752
def has_segmentation_mask(self):
715753
return self.segmentation_mask_filename(LEFT_LENS_NAME).exists()
716754

755+
@functools.cached_property
756+
def segmentation_annotation_type(self) -> str:
757+
segmentation_annotation_type = self.session.segmentation_annotation_type(
758+
self.sensor_name
759+
)
760+
if segmentation_annotation_type is None:
761+
return 'NA'
762+
else:
763+
return segmentation_annotation_type[self.frame_num]
764+
717765
@property
718766
def has_camera_pose(self):
719767
# all examples have some camera pose

sanpo_dataset/lib/common_test.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424
FLAGS = flags.FLAGS
2525

26-
_REAL_SESSIONS_PATH = 'third_party/py/sanpo_dataset/lib/testdata/sanpo-real'
26+
_REAL_SESSIONS_PATH = (
27+
'sanpo_dataset/sanpo_dataset/lib/testdata/sanpo-real'
28+
)
2729
_SYNTHETIC_SESSIONS_PATH = (
28-
'third_party/py/sanpo_dataset/lib/testdata/sanpo-synthetic'
30+
'sanpo_dataset/sanpo_dataset/lib/testdata/sanpo-synthetic'
2931
)
3032
_REAL_SESSION_NAME = 'real_session'
3133
_SYNTHETIC_SESSION_NAME = 'synthetic_session'
@@ -45,19 +47,17 @@ def setUp(self):
4547
super().setUp()
4648

4749
self.real_sessions_base_dir = os.path.join(
48-
FLAGS.test_srcdir, 'goo''gle3', _REAL_SESSIONS_PATH
50+
FLAGS.test_srcdir, _REAL_SESSIONS_PATH
4951
)
5052
self.real_session_dir = os.path.join(
51-
FLAGS.test_srcdir, 'goo''gle3', _REAL_SESSIONS_PATH, _REAL_SESSION_NAME
53+
FLAGS.test_srcdir, _REAL_SESSIONS_PATH, _REAL_SESSION_NAME
5254
)
5355
self.synthetic_sessions_base_dir = os.path.join(
5456
FLAGS.test_srcdir,
55-
'goo''gle3',
5657
_SYNTHETIC_SESSIONS_PATH,
5758
)
5859
self.synthetic_session_dir = os.path.join(
5960
FLAGS.test_srcdir,
60-
'goo''gle3',
6161
_SYNTHETIC_SESSIONS_PATH,
6262
_SYNTHETIC_SESSION_NAME,
6363
)
@@ -133,6 +133,15 @@ def test_sanpo_real_session(self):
133133
else _N_REAL_CAMERA_HEAD_FRAMES
134134
),
135135
)
136+
segmentation_annotation_type = sanpo_session.segmentation_annotation_type(
137+
sensor_name
138+
)
139+
if sensor_name == 'camera_chest':
140+
# This session->sensor_name has 73 frames. There must be annotation type
141+
# for each of them.
142+
self.assertLen(segmentation_annotation_type, 73)
143+
else:
144+
self.assertIsNone(segmentation_annotation_type)
136145

137146
def test_frame_example(self):
138147
sanpo_session = common.SanpoSession(
@@ -147,6 +156,9 @@ def test_frame_example(self):
147156
self.assertTrue(frame_example.has_right_lens)
148157
if sensor_name == 'camera_chest':
149158
self.assertTrue(frame_example.has_segmentation_mask)
159+
self.assertEqual(
160+
frame_example.segmentation_annotation_type, 'HUMAN_ANNOTATED'
161+
)
150162
else:
151163
self.assertFalse(frame_example.has_segmentation_mask)
152164

@@ -164,6 +176,7 @@ def test_synthetic_frame_example(self):
164176
self.assertFalse(frame_example.has_metric_depth_zed)
165177
self.assertFalse(frame_example.has_right_lens)
166178
self.assertTrue(frame_example.has_segmentation_mask)
179+
self.assertEqual(frame_example.segmentation_annotation_type, 'SYNTHETIC')
167180

168181
def test_sanpo_session_list(self):
169182
session_ids_file = tempfile.mktemp()
@@ -197,7 +210,7 @@ def test_all_frame_iter_samples_panoptic_frame_mode(self):
197210
)
198211
count = 0
199212
for example in sanpo_session.all_frame_itersamples():
200-
self.assertLen(example, 7)
213+
self.assertLen(example, 8)
201214
count += 1
202215
self.assertEqual(count, _N_REAL_CAMERA_CHEST_FRAMES)
203216

@@ -210,7 +223,7 @@ def test_all_frame_iter_samples_all_optional_frame_mode(self):
210223
sanpo_session = common.SanpoSession(self.real_session_dir, config)
211224
count = 0
212225
for example in sanpo_session.all_frame_itersamples():
213-
self.assertLen(example, 14)
226+
self.assertLen(example, 15)
214227
count += 1
215228
self.assertEqual(
216229
count, _N_REAL_CAMERA_CHEST_FRAMES + _N_REAL_CAMERA_HEAD_FRAMES

sanpo_dataset/lib/tensorflow_dataset.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pathlib
2323
import random
2424
from typing import Any, Iterator, Mapping, Optional, Tuple, Union
25+
import warnings
2526

2627
import numpy as np
2728
from sanpo_dataset.lib import common
@@ -60,16 +61,15 @@ def __init__(
6061
f'[{target_h}, {target_w}] which looks like [width,height].'
6162
)
6263
if abs(target_w * 9 / 16 - target_h) > 1:
63-
raise ValueError(
64+
warnings.warn(
6465
f'The target shape [{target_h},{target_w}] aspect ratio must be'
65-
f' 16:9. Consider setting a target_shape of either [{target_h},'
66-
f' {int(target_h*16/9)}] or [{int(target_w*9/16)}, {target_w}],'
67-
' which would preserve the image aspect ratio.\n\nSANPO does not'
68-
' perform cropping or color augmentation for you because'
69-
' preprocessing strategies can vary by application.'
66+
' 16:9 or else camera intrinsics will be incorrect at the ratio'
67+
' you requested.'
7068
# TODO(kwilber): add a crop tool and uncomment the below lines
7169
# f'To crop the image, you can use the `common.crop_*` '
7270
# f'family of functions which properly adjust camera intrinsics.'
71+
,
72+
UserWarning,
7373
)
7474

7575
# TODO(kwilber): Verify the config.
@@ -177,6 +177,9 @@ def _get_tensor_signature(self) -> Mapping[str, tf.TensorSpec]:
177177
signature[common.FEATURE_HAS_PANOPTIC_MASK_LABEL] = tf.TensorSpec(
178178
shape=(), dtype=tf.bool
179179
)
180+
signature[common.FEATURE_SEGMENTATION_ANNOTATION_TYPE] = tf.TensorSpec(
181+
shape=(), dtype=tf.string
182+
)
180183

181184
if self.builder_config.feature_camera_pose.to_include():
182185
signature[common.FEATURE_TRACKING_STATE] = tf.TensorSpec(
@@ -202,7 +205,9 @@ def _maybe_resize(
202205
resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
203206
else:
204207
resize_method = tf.image.ResizeMethod.BILINEAR
205-
return tf.image.resize(tensor, [target_h, target_w], method=resize_method)
208+
return tf.image.resize_with_pad(
209+
tensor, target_h, target_w, method=resize_method
210+
)
206211

207212
def _tf_decode_image(self, filename: tf.Tensor) -> tf.Tensor:
208213
# can't use tf.io.decode_image here because
@@ -283,6 +288,9 @@ def _tf_load_panoptic_labels(
283288
common.FEATURE_HAS_PANOPTIC_MASK_LABEL: tf.convert_to_tensor(
284289
features[common.FEATURE_HAS_PANOPTIC_MASK_LABEL]
285290
),
291+
common.FEATURE_SEGMENTATION_ANNOTATION_TYPE: tf.convert_to_tensor(
292+
features[common.FEATURE_SEGMENTATION_ANNOTATION_TYPE]
293+
),
286294
}
287295

288296
return {}

0 commit comments

Comments (0)