@@ -72,14 +72,14 @@ def load_calibration_dataset(directory: str, input_shape: List[int], batchsize:
7272
7373def quantize (working_directory : str , setting : QuantizationSetting , model_type : NetworkFramework ,
7474 executing_device : str , input_shape : List [int ], target_platform : TargetPlatform ,
75- dataloader : DataLoader , calib_steps : int = 32 ) -> BaseGraph :
75+ dataloader : DataLoader , calib_steps : int = 32 ) -> BaseGraph :
7676 if model_type == NetworkFramework .ONNX :
7777 if not os .path .exists (os .path .join (working_directory , 'model.onnx' )):
7878 raise FileNotFoundError (f'无法找到你的模型: { os .path .join (working_directory , "model.onnx" )} ,'
7979 '如果你使用caffe的模型,请设置MODEL_TYPE为CAFFE' )
8080 return quantize_onnx_model (
8181 onnx_import_file = os .path .join (working_directory , 'model.onnx' ),
82- calib_dataloader = dataloader , calib_steps = 32 , input_shape = input_shape , setting = setting ,
82+ calib_dataloader = dataloader , calib_steps = calib_steps , input_shape = input_shape , setting = setting ,
8383 platform = target_platform , device = executing_device , collate_fn = lambda x : x .to (executing_device )
8484 )
8585 if model_type == NetworkFramework .CAFFE :
0 commit comments