 # except in compliance with the License. See the license file found in the
 # LICENSE file in the root directory of this source tree.
 
+# mypy: disable-error-code=import-untyped
+
 import argparse
 import time
+from typing import cast, List, Optional
 
 import executorch
 
 import torchvision.models as torchvision_models
 from executorch.backends.openvino.partitioner import OpenvinoPartitioner
 from executorch.backends.openvino.quantizer import quantize_model
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
+from executorch.exir import (
+    EdgeProgramManager,
+    ExecutorchProgramManager,
+    to_edge_transform_and_lower,
+)
 from executorch.exir.backend.backend_details import CompileSpec
 from executorch.runtime import Runtime
 from sklearn.metrics import accuracy_score
@@ -102,7 +109,7 @@ def load_calibration_dataset(
 
 
 def infer_model(
-    exec_prog: EdgeProgramManager,
+    exec_prog: ExecutorchProgramManager,
     inputs,
     num_iter: int,
     warmup_iter: int,
@@ -111,7 +118,7 @@ def infer_model(
     """
     Executes inference and reports the average timing.
 
-    :param exec_prog: EdgeProgramManager of the lowered model
+    :param exec_prog: ExecutorchProgramManager of the lowered model
     :param inputs: The inputs for the model.
     :param num_iter: The number of iterations to execute inference for timing.
     :param warmup_iter: The number of iterations to execute inference for warmup before timing.
@@ -122,8 +129,11 @@ def infer_model(
     runtime = Runtime.get()
     program = runtime.load_program(exec_prog.buffer)
     method = program.load_method("forward")
+    if method is None:
+        raise ValueError("Load method failed")
 
     # Execute warmup
+    out = None
     for _i in range(warmup_iter):
         out = method.execute(inputs)
 
@@ -137,34 +147,38 @@ def infer_model(
 
     # Save output tensor as raw tensor file
     if output_path:
+        assert out is not None
         torch.save(out, output_path)
 
     # Return average inference timing
     return time_total / float(num_iter)
 
 
 def validate_model(
-    exec_prog: EdgeProgramManager, calibration_dataset: torch.utils.data.DataLoader
+    exec_prog: ExecutorchProgramManager,
+    calibration_dataset: torch.utils.data.DataLoader,
 ) -> float:
     """
     Validates the model using the calibration dataset.
 
-    :param exec_prog: EdgeProgramManager of the lowered model
+    :param exec_prog: ExecutorchProgramManager of the lowered model
     :param calibration_dataset: A DataLoader containing calibration data.
     :return: The accuracy score of the model.
     """
     # Load model from buffer
     runtime = Runtime.get()
     program = runtime.load_program(exec_prog.buffer)
     method = program.load_method("forward")
+    if method is None:
+        raise ValueError("Load method failed")
 
     # Iterate over the dataset and run the executor
-    predictions = []
+    predictions: List[int] = []
     targets = []
     for _idx, data in enumerate(calibration_dataset):
         feature, target = data
         targets.extend(target)
-        out = method.execute((feature,))
+        out = list(method.execute((feature,)))
         predictions.extend(torch.stack(out).reshape(-1, 1000).argmax(-1))
 
     # Check accuracy
@@ -213,12 +227,18 @@ def main( # noqa: C901
     model = load_model(suite, model_name)
     model = model.eval()
 
+    calibration_dataset: Optional[torch.utils.data.DataLoader] = None
+
     if dataset_path:
         calibration_dataset = load_calibration_dataset(
             dataset_path, batch_size, suite, model, model_name
         )
-        input_shape = tuple(next(iter(calibration_dataset))[0].shape)
-        print(f"Input shape retrieved from the model config: {input_shape}")
+        if calibration_dataset is not None:
+            input_shape = tuple(next(iter(calibration_dataset))[0].shape)
+            print(f"Input shape retrieved from the model config: {input_shape}")
+        else:
+            msg = "Quantization requires a valid calibration dataset"
+            raise ValueError(msg)
     # Ensure input_shape is a tuple
     elif isinstance(input_shape, (list, tuple)):
         input_shape = tuple(input_shape)
@@ -240,7 +260,7 @@ def main( # noqa: C901
     # Export the model to the aten dialect
     aten_dialect: ExportedProgram = export(model, example_args)
 
-    if quantize:
+    if quantize and calibration_dataset:
         if suite == "huggingface":
             msg = f"Quantization of {suite} models did not support yet."
             raise ValueError(msg)
@@ -251,20 +271,20 @@ def main( # noqa: C901
             raise ValueError(msg)
 
         subset_size = 300
-        batch_size = calibration_dataset.batch_size
+        batch_size = calibration_dataset.batch_size or 1
         subset_size = (subset_size // batch_size) + int(subset_size % batch_size > 0)
 
         def transform_fn(x):
             return x[0]
 
         quantized_model = quantize_model(
-            aten_dialect.module(),
+            cast(torch.fx.GraphModule, aten_dialect.module()),
             calibration_dataset,
             subset_size=subset_size,
             transform_fn=transform_fn,
         )
 
-        aten_dialect: ExportedProgram = export(quantized_model, example_args)
+        aten_dialect = export(quantized_model, example_args)
 
     # Convert to edge dialect and lower the module to the backend with a custom partitioner
     compile_spec = [CompileSpec("device", device.encode())]
@@ -288,7 +308,7 @@ def transform_fn(x):
         exec_prog.write_to_file(file)
     print(f"Model exported and saved as {model_file_name} on {device}.")
 
-    if validate:
+    if validate and calibration_dataset:
         if suite == "huggingface":
             msg = f"Validation of {suite} models did not support yet."
             raise ValueError(msg)
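
For reference, a minimal standalone sketch of the Runtime usage pattern these changes guard, assuming an already exported program; the "model.pte" file name and the (1, 3, 224, 224) input below are illustrative placeholders, not values from this commit.

import torch
from executorch.runtime import Runtime

# Hypothetical caller mirroring infer_model(): load a serialized program from
# bytes, then guard load_method(), which is typed as returning an Optional.
runtime = Runtime.get()
with open("model.pte", "rb") as f:  # placeholder path, not from this commit
    program = runtime.load_program(f.read())
method = program.load_method("forward")
if method is None:
    raise ValueError("Load method failed")
# execute() takes a tuple of inputs and returns a sequence of output tensors
outputs = list(method.execute((torch.randn(1, 3, 224, 224),)))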