Skip to content

Commit 2a6e727

Browse files
baorepo, Pillar1989
authored and committed
sscma: datasets: back to numpy backend
1 parent c9ffea8 commit 2a6e727

File tree

4 files changed

+856
-1109
lines changed

4 files changed

+856
-1109
lines changed

sscma/datasets/transforms/formatting.py

Lines changed: 65 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ def to_tensor(data):
3939
"`Sequence`, `int` and `float`"
4040
)
4141

42-
43-
@TRANSFORMS.register_module()
4442
class PackDetInputs(BaseTransform):
4543
"""Pack the inputs data for the detection / semantic segmentation /
4644
panoptic segmentation.
@@ -70,54 +68,57 @@ class PackDetInputs(BaseTransform):
7068
Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape',
7169
'scale_factor', 'flip', 'flip_direction')``
7270
"""
73-
7471
mapping_table = {
75-
"gt_bboxes": "bboxes",
76-
"gt_bboxes_labels": "labels",
77-
"gt_masks": "masks",
72+
'gt_bboxes': 'bboxes',
73+
'gt_bboxes_labels': 'labels',
74+
'gt_masks': 'masks',
75+
'gt_keypoints': 'keypoints',
76+
'gt_keypoints_visible': 'keypoints_visible'
7877
}
7978

80-
def __init__(
81-
self,
82-
meta_keys=(
83-
"img_id",
84-
"img_path",
85-
"ori_shape",
86-
"img_shape",
87-
"scale_factor",
88-
"flip",
89-
"flip_direction",
90-
),
91-
):
79+
def __init__(self,
80+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
81+
'scale_factor', 'flip', 'flip_direction')):
9282
self.meta_keys = meta_keys
9383

84+
9485
def transform(self, results: dict) -> dict:
9586
"""Method to pack the input data.
96-
9787
Args:
9888
results (dict): Result dict from the data pipeline.
99-
10089
Returns:
10190
dict:
102-
10391
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
10492
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
10593
sample.
10694
"""
10795
packed_results = dict()
96+
if 'img' in results:
97+
img = results['img']
98+
if len(img.shape) < 3:
99+
img = np.expand_dims(img, -1)
100+
# To improve the computational speed by 3-5 times, apply:
101+
# If image is not contiguous, use
102+
# `numpy.transpose()` followed by `numpy.ascontiguousarray()`
103+
# If image is already contiguous, use
104+
# `torch.permute()` followed by `torch.contiguous()`
105+
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
106+
# for more details
107+
if not img.flags.c_contiguous:
108+
img = np.ascontiguousarray(img.transpose(2, 0, 1))
109+
img = to_tensor(img)
110+
else:
111+
img = to_tensor(img).permute(2, 0, 1).contiguous()
108112

109-
if not results.get("torch", False):
110-
results["img"] = V2F.to_dtype(
111-
V2F.to_image(results["img"].copy()), torch.uint8, scale=True
112-
)
113-
results["torch"] = True
114-
115-
if "img" in results:
116-
packed_results["inputs"] = results["img"]
113+
packed_results['inputs'] = img
117114

118-
if "gt_ignore_flags" in results:
119-
valid_idx = np.where(results["gt_ignore_flags"] == 0)[0]
120-
ignore_idx = np.where(results["gt_ignore_flags"] == 1)[0]
115+
if 'gt_ignore_flags' in results:
116+
valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
117+
ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
118+
if 'gt_keypoints' in results:
119+
results['gt_keypoints_visible'] = results[
120+
'gt_keypoints'].keypoints_visible
121+
results['gt_keypoints'] = results['gt_keypoints'].keypoints
121122

122123
data_sample = DetDataSample()
123124
instance_data = InstanceData()
@@ -126,59 +127,60 @@ def transform(self, results: dict) -> dict:
126127
for key in self.mapping_table.keys():
127128
if key not in results:
128129
continue
129-
if key == "gt_masks" or isinstance(results[key], BaseBoxes):
130-
if "gt_ignore_flags" in results:
131-
instance_data[self.mapping_table[key]] = results[key][valid_idx]
132-
ignore_instance_data[self.mapping_table[key]] = results[key][
133-
ignore_idx
134-
]
130+
if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
131+
if 'gt_ignore_flags' in results:
132+
instance_data[
133+
self.mapping_table[key]] = results[key][valid_idx]
134+
ignore_instance_data[
135+
self.mapping_table[key]] = results[key][ignore_idx]
135136
else:
136137
instance_data[self.mapping_table[key]] = results[key]
137138
else:
138-
if "gt_ignore_flags" in results:
139+
if 'gt_ignore_flags' in results:
139140
instance_data[self.mapping_table[key]] = to_tensor(
140-
results[key][valid_idx]
141-
)
141+
results[key][valid_idx])
142142
ignore_instance_data[self.mapping_table[key]] = to_tensor(
143-
results[key][ignore_idx]
144-
)
143+
results[key][ignore_idx])
145144
else:
146-
instance_data[self.mapping_table[key]] = to_tensor(results[key])
145+
instance_data[self.mapping_table[key]] = to_tensor(
146+
results[key])
147147
data_sample.gt_instances = instance_data
148148
data_sample.ignored_instances = ignore_instance_data
149149

150-
if "proposals" in results:
151-
proposals = InstanceData(
152-
bboxes=to_tensor(results["proposals"]),
153-
scores=to_tensor(results["proposals_scores"]),
154-
)
155-
data_sample.proposals = proposals
156-
157-
if "gt_seg_map" in results:
150+
if 'gt_seg_map' in results:
158151
gt_sem_seg_data = dict(
159-
sem_seg=to_tensor(results["gt_seg_map"][None, ...].copy())
160-
)
161-
gt_sem_seg_data = PixelData(**gt_sem_seg_data)
162-
if "ignore_index" in results:
163-
metainfo = dict(ignore_index=results["ignore_index"])
164-
gt_sem_seg_data.set_metainfo(metainfo)
165-
data_sample.gt_sem_seg = gt_sem_seg_data
152+
sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
153+
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
154+
155+
# In order to unify the support for the overlap mask annotations
156+
# i.e. mask overlap annotations in (h,w) format,
157+
# we use the gt_panoptic_seg field to unify the modeling
158+
if 'gt_panoptic_seg' in results:
159+
data_sample.gt_panoptic_seg = PixelData(
160+
pan_seg=results['gt_panoptic_seg'])
166161

167162
img_meta = {}
168163
for key in self.meta_keys:
169-
if key in results:
170-
img_meta[key] = results[key]
164+
assert key in results, f'`{key}` is not found in `results`, ' \
165+
f'the valid keys are {list(results)}.'
166+
img_meta[key] = results[key]
167+
171168
data_sample.set_metainfo(img_meta)
172-
packed_results["data_samples"] = data_sample
169+
packed_results['data_samples'] = data_sample
173170

174171
return packed_results
175172

173+
176174
def __repr__(self) -> str:
177175
repr_str = self.__class__.__name__
178-
repr_str += f"(meta_keys={self.meta_keys})"
176+
repr_str += f'(meta_keys={self.meta_keys})'
179177
return repr_str
180178

181179

180+
181+
182+
183+
182184
class PackInputs(BaseTransform):
183185
"""Pack the inputs data.
184186

0 commit comments

Comments
 (0)