@@ -51,7 +51,7 @@ def load_model(suite: str, model_name: str):
51
51
raise ValueError (msg )
52
52
53
53
54
- def load_calibration_dataset (dataset_path : str , suite : str , model : torch .nn .Module , model_name : str ):
54
+ def load_calibration_dataset (dataset_path : str , batch_size : int , suite : str , model : torch .nn .Module , model_name : str ):
55
55
val_dir = f"{ dataset_path } /val"
56
56
57
57
if suite == "torchvision" :
@@ -62,7 +62,7 @@ def load_calibration_dataset(dataset_path: str, suite: str, model: torch.nn.Modu
62
62
val_dataset = datasets .ImageFolder (val_dir , transform = transform )
63
63
64
64
calibration_dataset = torch .utils .data .DataLoader (
65
- val_dataset , batch_size = 1 , shuffle = False , num_workers = 0 , pin_memory = True
65
+ val_dataset , batch_size = batch_size , shuffle = False , num_workers = 0 , pin_memory = True
66
66
)
67
67
68
68
return calibration_dataset
@@ -77,7 +77,7 @@ def dump_inputs(calibration_dataset, dest_path):
77
77
input_files , targets = [], []
78
78
for idx , data in enumerate (calibration_dataset ):
79
79
feature , target = data
80
- targets .append (target )
80
+ targets .extend (target )
81
81
file_name = f"{ dest_path } /input_{ idx } _0.raw"
82
82
if not isinstance (feature , torch .Tensor ):
83
83
feature = torch .tensor (feature )
@@ -87,13 +87,22 @@ def dump_inputs(calibration_dataset, dest_path):
87
87
return input_files , targets
88
88
89
89
90
- def main (suite : str , model_name : str , input_shape , quantize : bool , validate : bool , dataset_path : str , device : str ):
90
+ def main (
91
+ suite : str ,
92
+ model_name : str ,
93
+ input_shape ,
94
+ quantize : bool ,
95
+ validate : bool ,
96
+ dataset_path : str ,
97
+ device : str ,
98
+ batch_size : int ,
99
+ ):
91
100
# Load the selected model
92
101
model = load_model (suite , model_name )
93
102
model = model .eval ()
94
103
95
104
if dataset_path :
96
- calibration_dataset = load_calibration_dataset (dataset_path , suite , model , model_name )
105
+ calibration_dataset = load_calibration_dataset (dataset_path , batch_size , suite , model , model_name )
97
106
input_shape = tuple (next (iter (calibration_dataset ))[0 ].shape )
98
107
print (f"Input shape retrieved from the model config: { input_shape } " )
99
108
# Ensure input_shape is a tuple
@@ -192,7 +201,7 @@ def transform(x):
192
201
predictions = []
193
202
for i in range (len (input_files )):
194
203
tensor = np .fromfile (out_path / f"output_{ i } _0.raw" , dtype = np .float32 )
195
- predictions .append (torch .argmax ( torch . tensor (tensor )))
204
+ predictions .extend (torch .tensor (tensor ). reshape ( - 1 , 1000 ). argmax ( - 1 ))
196
205
197
206
acc_top1 = accuracy_score (predictions , targets )
198
207
print (f"acc@1: { acc_top1 } " )
@@ -214,6 +223,13 @@ def transform(x):
214
223
type = eval ,
215
224
help = "Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224))." ,
216
225
)
226
+ parser .add_argument (
227
+ "--batch_size" ,
228
+ type = int ,
229
+ default = 1 ,
230
+ help = "Batch size for the validation. Default batch_size == 1."
231
+ " The dataset length must be evenly divisible by the batch size." ,
232
+ )
217
233
parser .add_argument ("--quantize" , action = "store_true" , help = "Enable model quantization." )
218
234
parser .add_argument (
219
235
"--validate" ,
@@ -232,4 +248,13 @@ def transform(x):
232
248
233
249
# Run the main function with parsed arguments
234
250
with nncf .torch .disable_patching ():
235
- main (args .suite , args .model , args .input_shape , args .quantize , args .validate , args .dataset , args .device )
251
+ main (
252
+ args .suite ,
253
+ args .model ,
254
+ args .input_shape ,
255
+ args .quantize ,
256
+ args .validate ,
257
+ args .dataset ,
258
+ args .device ,
259
+ args .batch_size ,
260
+ )
0 commit comments