99
1010import argparse
1111from itertools import islice
12- from typing import Any , Iterator , Tuple
12+ from typing import Any , Dict , Iterator , Optional , Tuple
1313
1414import cv2
1515import executorch
2828 to_edge_transform_and_lower ,
2929)
3030from executorch .exir .backend .backend_details import CompileSpec
31+ from executorch .runtime import Runtime
3132from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
3233from torch .export .exported_program import ExportedProgram
3334from torch .fx .passes .graph_drawer import FxGraphDrawer
3435from ultralytics import YOLO
3536
37+ from ultralytics .data .utils import check_det_dataset
38+ from ultralytics .engine .validator import BaseValidator as Validator
39+ from ultralytics .utils .torch_utils import de_parallel
40+
3641
3742class CV2VideoIter :
3843 def __init__ (self , cap ) -> None :
@@ -204,17 +209,21 @@ def main(
204209 subset_size : int ,
205210 backend : str ,
206211 device : str ,
212+ val_dataset_yaml_path : Optional [str ],
207213):
208214 """
209215 Main function to load, quantize, and export an Yolo model model.
210216
211217 :param model_name: The name of the YOLO model to load.
218+ :param input_dims: Input dims to use for the export of a YOLO12 model.
212219 :param quantize: Whether to quantize the model.
213220 :param video_path: Path to the video to use for the calibration
221+ :param subset_size: Subset size for the quantized model calibration. The default value is 300.
214222 :param backend: The Executorch inference backend (e.g., "openvino", "xnnpack").
215223 :param device: The device to run the model on (e.g., "cpu", "gpu").
224+ :param val_dataset_yaml_path: Path to the validation dataset file in Ultralytics .yaml format.
225+ Performs validation if the path is not None, skips validation otherwise.
216226 """
217-
218227 # Load the selected model
219228 model = YOLO (model_name )
220229
@@ -267,6 +276,67 @@ def transform_fn(frame):
267276 exec_prog .write_to_file (file )
268277 print (f"Model exported and saved as { model_file_name } on { device } ." )
269278
279+ if val_dataset_yaml_path is not None :
280+ if input_dims != [640 , 640 ]:
281+ raise NotImplementedError (
282+ f"Validation with the custom input shape { input_dims } is not implmenented."
283+ " Please use the default --input_dims=[640, 640] for the validation."
284+ )
285+ stats = validate_yolo (model , exec_prog , val_dataset_yaml_path )
286+ for stat , value in stats .items ():
287+ print (f"{ stat } : { value } " )
288+
289+
290+ def _prepare_validation (
291+ model : YOLO , dataset_yaml_path : str
292+ ) -> Tuple [Validator , torch .utils .data .DataLoader ]:
293+ custom = {"rect" : False , "batch" : 1 } # method defaults
294+ args = {
295+ ** model .overrides ,
296+ ** custom ,
297+ "mode" : "val" ,
298+ } # highest priority args on the right
299+
300+ validator = model ._smart_load ("validator" )(args = args , _callbacks = model .callbacks )
301+ stride = 32 # default stride
302+ validator .stride = stride # used in get_dataloader() for padding
303+ validator .data = check_det_dataset (dataset_yaml_path )
304+ validator .init_metrics (de_parallel (model ))
305+
306+ data_loader = validator .get_dataloader (
307+ validator .data .get (validator .args .split ), validator .args .batch
308+ )
309+
310+ return validator , data_loader
311+
312+
313+ def validate_yolo (
314+ model : YOLO , exec_prog : ExecutorchProgramManager , dataset_yaml_path : str
315+ ) -> Dict [str , float ]:
316+ """
317+ Runs validation on a YOLO model using an ExecuTorch program and a dataset in Ultralytics format.
318+
319+ :param model: The YOLO model instance to validate.
320+ :param exec_prog: The ExecuTorch program manager containing the compiled model.
321+ :param dataset_yaml_path: Path to the validation dataset file in Ultralytics .yaml format.
322+ :return: Dictionary of validation statistics computed over the dataset.
323+ """
324+ # Load model from buffer
325+ runtime = Runtime .get ()
326+ program = runtime .load_program (exec_prog .buffer )
327+ method = program .load_method ("forward" )
328+ if method is None :
329+ raise ValueError ("Load method failed" )
330+ validator , data_loader = _prepare_validation (model , dataset_yaml_path )
331+ print (f"Start validation on { dataset_yaml_path } dataset ..." )
332+ for batch in data_loader :
333+ batch = validator .preprocess (batch )
334+ preds = method .execute ((batch ["img" ],))
335+ preds = validator .postprocess (preds )
336+ validator .update_metrics (preds , batch )
337+ stats = validator .get_stats ()
338+ return stats
339+
270340
271341if __name__ == "__main__" :
272342 parser = argparse .ArgumentParser (
@@ -312,6 +382,13 @@ def transform_fn(frame):
312382 default = "CPU" ,
313383 help = "Target device for compiling the model (e.g., CPU, GPU). Default is CPU." ,
314384 )
385+ parser .add_argument (
386+ "--validate" ,
387+ nargs = "?" ,
388+ const = "coco128.yaml" ,
389+ help = "Validate executorch model using the Ultralytics validation pipeline."
390+ " Default validateion dataset is coco128.yaml." ,
391+ )
315392
316393 args = parser .parse_args ()
317394
@@ -320,6 +397,7 @@ def transform_fn(frame):
320397 model_name = args .model_name ,
321398 input_dims = args .input_dims ,
322399 quantize = args .quantize ,
400+ val_dataset_yaml_path = args .validate ,
323401 video_path = args .video_path ,
324402 subset_size = args .subset_size ,
325403 backend = args .backend ,
0 commit comments