44from copy import deepcopy
55
66from mmengine import Config
7+ from torch .utils .data import DataLoader
78
89from mmdeploy .apis .utils import build_task_processor
910from mmdeploy .utils import get_root_logger , load_config
@@ -31,9 +32,11 @@ def get_table(onnx_path: str,
3132 from quant_image_dataset import QuantizationImageDataset
3233 dataset = QuantizationImageDataset (
3334 path = image_dir , deploy_cfg = deploy_cfg , model_cfg = model_cfg )
34- calib_dataloader ['dataset' ] = dataset
35- dataloader = task_processor .build_dataloader (calib_dataloader )
36- # dataloader = DataLoader(dataset, batch_size=1)
35+
36+ def collate (data_batch ):
37+ return data_batch [0 ]
38+
39+ dataloader = DataLoader (dataset , batch_size = 1 , collate_fn = collate )
3740 else :
3841 dataset = task_processor .build_dataset (calib_dataloader ['dataset' ])
3942 calib_dataloader ['dataset' ] = dataset
@@ -44,16 +47,10 @@ def get_table(onnx_path: str,
4447 # get an available input shape randomly
4548 for _ , input_data in enumerate (dataloader ):
4649 input_data = data_preprocessor (input_data )
47- input_tensor = input_data [0 ]
48- if isinstance (input_tensor , list ):
49- input_shape = input_tensor [0 ].shape
50- collate_fn = lambda x : data_preprocessor (x [0 ])[0 ].to ( # noqa: E731
51- device )
52- else :
53- input_shape = input_tensor .shape
54- collate_fn = lambda x : data_preprocessor (x )[0 ].to ( # noqa: E731
55- device )
56- break
50+ input_tensor = input_data ['inputs' ]
51+ input_shape = input_tensor .shape
52+ collate_fn = lambda x : data_preprocessor (x )['inputs' ].to ( # noqa: E731
53+ device )
5754
5855 from ppq import QuantizationSettingFactory , TargetPlatform
5956 from ppq .api import export_ppq_graph , quantize_onnx_model
0 commit comments