@@ -39,8 +39,6 @@ def to_tensor(data):
3939 "`Sequence`, `int` and `float`"
4040 )
4141
42-
43- @TRANSFORMS .register_module ()
4442class PackDetInputs (BaseTransform ):
4543 """Pack the inputs data for the detection / semantic segmentation /
4644 panoptic segmentation.
@@ -70,54 +68,57 @@ class PackDetInputs(BaseTransform):
7068 Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape',
7169 'scale_factor', 'flip', 'flip_direction')``
7270 """
73-
7471 mapping_table = {
75- "gt_bboxes" : "bboxes" ,
76- "gt_bboxes_labels" : "labels" ,
77- "gt_masks" : "masks" ,
72+ 'gt_bboxes' : 'bboxes' ,
73+ 'gt_bboxes_labels' : 'labels' ,
74+ 'gt_masks' : 'masks' ,
75+ 'gt_keypoints' : 'keypoints' ,
76+ 'gt_keypoints_visible' : 'keypoints_visible'
7877 }
7978
80- def __init__ (
81- self ,
82- meta_keys = (
83- "img_id" ,
84- "img_path" ,
85- "ori_shape" ,
86- "img_shape" ,
87- "scale_factor" ,
88- "flip" ,
89- "flip_direction" ,
90- ),
91- ):
79+ def __init__ (self ,
80+ meta_keys = ('img_id' , 'img_path' , 'ori_shape' , 'img_shape' ,
81+ 'scale_factor' , 'flip' , 'flip_direction' )):
9282 self .meta_keys = meta_keys
9383
84+
9485 def transform (self , results : dict ) -> dict :
9586 """Method to pack the input data.
96-
9787 Args:
9888 results (dict): Result dict from the data pipeline.
99-
10089 Returns:
10190 dict:
102-
10391 - 'inputs' (obj:`torch.Tensor`): The forward data of models.
10492 - 'data_sample' (obj:`DetDataSample`): The annotation info of the
10593 sample.
10694 """
10795 packed_results = dict ()
96+ if 'img' in results :
97+ img = results ['img' ]
98+ if len (img .shape ) < 3 :
99+ img = np .expand_dims (img , - 1 )
100+ # To improve the computational speed by 3-5 times, apply:
101+ # If image is not contiguous, use
102+ # `numpy.transpose()` followed by `numpy.ascontiguousarray()`
103+ # If image is already contiguous, use
104+ # `torch.permute()` followed by `torch.contiguous()`
105+ # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
106+ # for more details
107+ if not img .flags .c_contiguous :
108+ img = np .ascontiguousarray (img .transpose (2 , 0 , 1 ))
109+ img = to_tensor (img )
110+ else :
111+ img = to_tensor (img ).permute (2 , 0 , 1 ).contiguous ()
108112
109- if not results .get ("torch" , False ):
110- results ["img" ] = V2F .to_dtype (
111- V2F .to_image (results ["img" ].copy ()), torch .uint8 , scale = True
112- )
113- results ["torch" ] = True
114-
115- if "img" in results :
116- packed_results ["inputs" ] = results ["img" ]
113+ packed_results ['inputs' ] = img
117114
118- if "gt_ignore_flags" in results :
119- valid_idx = np .where (results ["gt_ignore_flags" ] == 0 )[0 ]
120- ignore_idx = np .where (results ["gt_ignore_flags" ] == 1 )[0 ]
115+ if 'gt_ignore_flags' in results :
116+ valid_idx = np .where (results ['gt_ignore_flags' ] == 0 )[0 ]
117+ ignore_idx = np .where (results ['gt_ignore_flags' ] == 1 )[0 ]
118+ if 'gt_keypoints' in results :
119+ results ['gt_keypoints_visible' ] = results [
120+ 'gt_keypoints' ].keypoints_visible
121+ results ['gt_keypoints' ] = results ['gt_keypoints' ].keypoints
121122
122123 data_sample = DetDataSample ()
123124 instance_data = InstanceData ()
@@ -126,59 +127,60 @@ def transform(self, results: dict) -> dict:
126127 for key in self .mapping_table .keys ():
127128 if key not in results :
128129 continue
129- if key == " gt_masks" or isinstance (results [key ], BaseBoxes ):
130- if " gt_ignore_flags" in results :
131- instance_data [self . mapping_table [ key ]] = results [ key ][ valid_idx ]
132- ignore_instance_data [ self .mapping_table [key ]] = results [key ][
133- ignore_idx
134- ]
130+ if key == ' gt_masks' or isinstance (results [key ], BaseBoxes ):
131+ if ' gt_ignore_flags' in results :
132+ instance_data [
133+ self .mapping_table [key ]] = results [key ][valid_idx ]
134+ ignore_instance_data [
135+ self . mapping_table [ key ]] = results [ key ][ ignore_idx ]
135136 else :
136137 instance_data [self .mapping_table [key ]] = results [key ]
137138 else :
138- if " gt_ignore_flags" in results :
139+ if ' gt_ignore_flags' in results :
139140 instance_data [self .mapping_table [key ]] = to_tensor (
140- results [key ][valid_idx ]
141- )
141+ results [key ][valid_idx ])
142142 ignore_instance_data [self .mapping_table [key ]] = to_tensor (
143- results [key ][ignore_idx ]
144- )
143+ results [key ][ignore_idx ])
145144 else :
146- instance_data [self .mapping_table [key ]] = to_tensor (results [key ])
145+ instance_data [self .mapping_table [key ]] = to_tensor (
146+ results [key ])
147147 data_sample .gt_instances = instance_data
148148 data_sample .ignored_instances = ignore_instance_data
149149
150- if "proposals" in results :
151- proposals = InstanceData (
152- bboxes = to_tensor (results ["proposals" ]),
153- scores = to_tensor (results ["proposals_scores" ]),
154- )
155- data_sample .proposals = proposals
156-
157- if "gt_seg_map" in results :
150+ if 'gt_seg_map' in results :
158151 gt_sem_seg_data = dict (
159- sem_seg = to_tensor (results ["gt_seg_map" ][None , ...].copy ())
160- )
161- gt_sem_seg_data = PixelData (** gt_sem_seg_data )
162- if "ignore_index" in results :
163- metainfo = dict (ignore_index = results ["ignore_index" ])
164- gt_sem_seg_data .set_metainfo (metainfo )
165- data_sample .gt_sem_seg = gt_sem_seg_data
152+ sem_seg = to_tensor (results ['gt_seg_map' ][None , ...].copy ()))
153+ data_sample .gt_sem_seg = PixelData (** gt_sem_seg_data )
154+
155+ # In order to unify the support for the overlap mask annotations
156+ # i.e. mask overlap annotations in (h,w) format,
157+ # we use the gt_panoptic_seg field to unify the modeling
158+ if 'gt_panoptic_seg' in results :
159+ data_sample .gt_panoptic_seg = PixelData (
160+ pan_seg = results ['gt_panoptic_seg' ])
166161
167162 img_meta = {}
168163 for key in self .meta_keys :
169- if key in results :
170- img_meta [key ] = results [key ]
164+ assert key in results , f'`{ key } ` is not found in `results`, ' \
165+ f'the valid keys are { list (results )} .'
166+ img_meta [key ] = results [key ]
167+
171168 data_sample .set_metainfo (img_meta )
172- packed_results [" data_samples" ] = data_sample
169+ packed_results [' data_samples' ] = data_sample
173170
174171 return packed_results
175172
173+
176174 def __repr__ (self ) -> str :
177175 repr_str = self .__class__ .__name__
178- repr_str += f" (meta_keys={ self .meta_keys } )"
176+ repr_str += f' (meta_keys={ self .meta_keys } )'
179177 return repr_str
180178
181179
180+
181+
182+
183+
182184class PackInputs (BaseTransform ):
183185 """Pack the inputs data.
184186
0 commit comments