@@ -51,11 +51,11 @@ def load_model(suite: str, model_name: str):
51
51
raise ValueError (msg )
52
52
53
53
54
- def load_calibration_dataset (dataset_path : str , suite : str , model : torch .nn .Module ):
54
+ def load_calibration_dataset (dataset_path : str , suite : str , model : torch .nn .Module , model_name : str ):
55
55
val_dir = f"{ dataset_path } /val"
56
56
57
57
if suite == "torchvision" :
58
- transform = torchvision_models .get_model_weights (model . name ) .transforms ()
58
+ transform = torchvision_models .get_model_weights (model_name ). DEFAULT .transforms ()
59
59
else :
60
60
transform = create_transform (** resolve_data_config (model .pretrained_cfg , model = model ))
61
61
@@ -87,14 +87,16 @@ def dump_inputs(calibration_dataset, dest_path):
87
87
return input_files , targets
88
88
89
89
90
- def main (suite : str , model_name : str , input_shape , quantize : bool , dataset_path : str , device : str ):
90
+ def main (suite : str , model_name : str , input_shape , quantize : bool , validate : bool , dataset_path : str , device : str ):
91
91
# Ensure input_shape is a tuple
92
92
if isinstance (input_shape , list ):
93
93
input_shape = tuple (input_shape )
94
94
elif not isinstance (input_shape , tuple ):
95
95
msg = "Input shape must be a list or tuple."
96
96
raise ValueError (msg )
97
97
98
+ calibration_dataset = None
99
+
98
100
# Load the selected model
99
101
model = load_model (suite , model_name )
100
102
model = model .eval ()
@@ -114,7 +116,7 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
114
116
if not dataset_path :
115
117
msg = "Quantization requires a calibration dataset."
116
118
raise ValueError (msg )
117
- calibration_dataset = load_calibration_dataset (dataset_path , suite , model )
119
+ calibration_dataset = load_calibration_dataset (dataset_path , suite , model , model_name )
118
120
119
121
captured_model = aten_dialect .module ()
120
122
quantizer = OpenVINOQuantizer ()
@@ -146,12 +148,15 @@ def transform(x):
146
148
exec_prog = lowered_module .to_executorch (config = executorch .exir .ExecutorchBackendConfig ())
147
149
148
150
# Serialize and save it to a file
149
- model_name = f"{ model_name } _{ 'int8' if quantize else 'fp32' } .pte"
150
- with open (model_name , "wb" ) as file :
151
+ model_file_name = f"{ model_name } _{ 'int8' if quantize else 'fp32' } .pte"
152
+ with open (model_file_name , "wb" ) as file :
151
153
exec_prog .write_to_file (file )
152
- print (f"Model exported and saved as { model_name } on { device } ." )
154
+ print (f"Model exported and saved as { model_file_name } on { device } ." )
155
+
156
+ if validate :
157
+ if calibration_dataset is None :
158
+ calibration_dataset = load_calibration_dataset (dataset_path , suite , model , model_name )
153
159
154
- if quantize :
155
160
print ("Start validation of the quantized model:" )
156
161
# 1: Dump inputs
157
162
dest_path = Path ("tmp_inputs" )
@@ -172,18 +177,17 @@ def transform(x):
172
177
subprocess .run (
173
178
[
174
179
"../../../cmake-openvino-out/examples/openvino/openvino_executor_runner" ,
175
- f"--model_path={ model_name } " ,
180
+ f"--model_path={ model_file_name } " ,
176
181
f"--input_list_path={ inp_list_file } " ,
177
182
f"--output_folder_path={ out_path } " ,
178
183
]
179
184
)
180
185
181
186
# 3: load the outputs and compare with the targets
182
-
183
187
predictions = []
184
188
for i in range (len (input_files )):
185
189
tensor = np .fromfile (out_path / f"output_{ i } _0.raw" , dtype = np .float32 )
186
- predictions .append (torch .tensor ( np . argmax (tensor )))
190
+ predictions .append (torch .argmax ( torch . tensor (tensor )))
187
191
188
192
acc_top1 = accuracy_score (predictions , targets )
189
193
print (f"acc@1: { acc_top1 } " )
@@ -207,6 +211,11 @@ def transform(x):
207
211
help = "Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224))." ,
208
212
)
209
213
parser .add_argument ("--quantize" , action = "store_true" , help = "Enable model quantization." )
214
+ parser .add_argument (
215
+ "--validate" ,
216
+ action = "store_true" ,
217
+ help = "Enable model validation. --dataset argument is requred for the validation." ,
218
+ )
210
219
parser .add_argument ("--dataset" , type = str , help = "Path to the calibration dataset." )
211
220
parser .add_argument (
212
221
"--device" ,
@@ -219,4 +228,4 @@ def transform(x):
219
228
220
229
# Run the main function with parsed arguments
221
230
with nncf .torch .disable_patching ():
222
- main (args .suite , args .model , args .input_shape , args .quantize , args .dataset , args .device )
231
+ main (args .suite , args .model , args .input_shape , args .quantize , args .validate , args . dataset , args .device )
0 commit comments