@@ -36,14 +36,20 @@ def process_model_config(
3636 cfg = copy .deepcopy (model_cfg )
3737 test_pipeline = cfg .test_dataloader .dataset .pipeline
3838 data_preprocessor = cfg .model .data_preprocessor
39- codec = cfg .codec
40- if isinstance (codec , list ):
41- codec = codec [- 1 ]
42- input_size = codec ['input_size' ] if input_shape is None else input_shape
39+ codec = getattr (cfg , 'codec' , None )
40+ if codec is not None :
41+ if isinstance (codec , list ):
42+ codec = codec [- 1 ]
43+ input_size = codec ['input_size' ] if input_shape is None \
44+ else input_shape
45+ else :
46+ input_size = cfg .img_scale
47+
4348 test_pipeline [0 ] = dict (type = 'LoadImageFromFile' )
4449 for i in reversed (range (len (test_pipeline ))):
4550 trans = test_pipeline [i ]
46- if trans ['type' ] == 'PackPoseInputs' :
51+ if trans ['type' ] == 'PackPoseInputs' or trans [
52+ 'type' ] == 'PackDetPoseInputs' :
4753 test_pipeline .pop (i )
4854 elif trans ['type' ] == 'GetBBoxCenterScale' :
4955 trans ['type' ] = 'TopDownGetBboxCenterScale'
@@ -53,22 +59,37 @@ def process_model_config(
5359 trans ['type' ] = 'TopDownAffine'
5460 trans ['image_size' ] = input_size
5561 trans .pop ('input_size' )
56-
57- test_pipeline .append (
58- dict (
59- type = 'Normalize' ,
60- mean = data_preprocessor .mean ,
61- std = data_preprocessor .std ,
62- to_rgb = data_preprocessor .bgr_to_rgb ))
62+ elif trans ['type' ][:6 ] == 'mmdet.' :
63+ trans ['type' ] = trans ['type' ][6 :]
64+
65+ # DetDataPreprocessor does not have mean, std, bgr_to_rgb
66+ # TODO: implement PoseToDetConverter and PackDetPoseInputs in c++
67+ if data_preprocessor .type != 'mmdet.DetDataPreprocessor' :
68+ test_pipeline .append (
69+ dict (
70+ type = 'Normalize' ,
71+ mean = data_preprocessor .mean ,
72+ std = data_preprocessor .std ,
73+ to_rgb = data_preprocessor .bgr_to_rgb ))
6374 test_pipeline .append (dict (type = 'ImageToTensor' , keys = ['img' ]))
64- test_pipeline .append (
65- dict (
66- type = 'Collect' ,
67- keys = ['img' ],
68- meta_keys = [
69- 'img_shape' , 'pad_shape' , 'ori_shape' , 'img_norm_cfg' ,
70- 'scale_factor' , 'bbox_score' , 'center' , 'scale'
71- ]))
75+ if data_preprocessor .type != 'mmdet.DetDataPreprocessor' :
76+ test_pipeline .append (
77+ dict (
78+ type = 'Collect' ,
79+ keys = ['img' ],
80+ meta_keys = [
81+ 'img_shape' , 'pad_shape' , 'ori_shape' , 'img_norm_cfg' ,
82+ 'scale_factor' , 'bbox_score' , 'center' , 'scale'
83+ ]))
84+ else :
85+ test_pipeline .append (
86+ dict (
87+ type = 'Collect' ,
88+ keys = ['img' ],
89+ meta_keys = [
90+ 'id' , 'img_id' , 'img_path' , 'ori_shape' , 'img_shape' ,
91+ 'scale_factor' , 'flip_indices'
92+ ]))
7293
7394 cfg .test_dataloader .dataset .pipeline = test_pipeline
7495 return cfg
@@ -345,13 +366,19 @@ def get_preprocess(self, *args, **kwargs) -> Dict:
345366
346367 def get_postprocess (self , * args , ** kwargs ) -> Dict :
347368 """Get the postprocess information for SDK."""
348- codec = self .model_cfg .codec
349- if isinstance (codec , (list , tuple )):
350- codec = codec [- 1 ]
351- component = 'UNKNOWN'
369+ codec = getattr (self .model_cfg , 'codec' , None )
352370 params = copy .deepcopy (self .model_cfg .model .test_cfg )
353- params .update (codec )
354- if self .model_cfg .model .type == 'TopdownPoseEstimator' :
371+ component = 'UNKNOWN'
372+ if codec is not None :
373+ if isinstance (codec , (list , tuple )):
374+ codec = codec [- 1 ]
375+ params .update (codec )
376+ else :
377+ # TODO: implement this in c++
378+ component = 'YOLOXPoseHeadDecode'
379+
380+ if self .model_cfg .model .type == 'TopdownPoseEstimator' \
381+ and codec is not None :
355382 component = 'TopdownHeatmapSimpleHeadDecode'
356383 if codec .type == 'MSRAHeatmap' :
357384 params ['post_process' ] = 'default'
0 commit comments