@@ -682,12 +682,12 @@ def set_input_size(self, input_size: Union[int, List[int], Tuple[int, int]]):
682682 self ._set_pipeline_size_value (pipelines , resize_ratio )
683683
684684 # Set model size
685- # - needed only for YOLOX
686685 model_cfg = self ._config .get ("model" , {})
686+ model_cfg ["input_size" ] = input_size
687687 if model_cfg .get ("type" , "" ) == "CustomYOLOX" :
688+ # - needed only for YOLOX
688689 if input_size [0 ] % 32 != 0 or input_size [1 ] % 32 != 0 :
689690 raise ValueError ("YOLOX should have input size being multiple of 32." )
690- model_cfg ["input_size" ] = input_size
691691
692692 @property
693693 def base_input_size (self ) -> Union [Tuple [int , int ], Dict [str , Tuple [int , int ]]]:
@@ -862,38 +862,28 @@ def _set_size_value(pipeline: Dict, attr: str, scale: Tuple[Union[int, float], U
862862 pipeline [attr ] = (round (pipeline [attr ][0 ] * scale [0 ]), round (pipeline [attr ][1 ] * scale [1 ]))
863863
864864 @staticmethod
865- def get_configured_input_size (
866- input_size_config : InputSizePreset = InputSizePreset .DEFAULT , model_ckpt : Optional [str ] = None
867- ) -> Optional [Tuple [int , int ]]:
868- """Get configurable input size configuration. If it doesn't exist, return None.
865+ def get_trained_input_size (model_ckpt : Optional [str ] = None ) -> Optional [Tuple [int , int ]]:
866+ """Get trained input size from checkpoint. If it doesn't exist, return None.
869867
870868 Args:
871- input_size_config (InputSizePreset, optional): Input size setting. Defaults to InputSizePreset.DEFAULT.
872869 model_ckpt (Optional[str], optional): Model weight to load. Defaults to None.
873870
874871 Returns:
875872 Optional[Tuple[int, int]]: Pair of width and height. If there is no input size configuration, return None.
876873 """
877- input_size = None
878- if input_size_config == InputSizePreset .DEFAULT :
879- if model_ckpt is None :
880- return None
881-
882- model_info = torch .load (model_ckpt , map_location = "cpu" )
883- for key in ["config" , "learning_parameters" , "input_size" , "value" ]:
884- if key not in model_info :
885- return None
886- model_info = model_info [key ]
887- input_size = model_info
888-
889- if input_size == InputSizePreset .DEFAULT .value :
890- return None
891- logger .info ("Given model weight was trained with {} input size." .format (input_size ))
874+ if model_ckpt is None :
875+ return None
892876
893- else :
894- input_size = input_size_config .value
877+ model_info = torch .load (model_ckpt , map_location = "cpu" )
878+ if model_info is None :
879+ return None
895880
896- return InputSizePreset .parse (input_size )
881+ input_size = model_info .get ("input_size" , None )
882+ if not input_size :
883+ return None
884+
885+ logger .info ("Given model weight was trained with {} input size." .format (input_size ))
886+ return input_size
897887
898888 @staticmethod
899889 def select_closest_size (input_size : Tuple [int , int ], preset_sizes : List [Tuple [int , int ]]):
0 commit comments