@@ -181,6 +181,8 @@ def main(
181181 suite : str ,
182182 model_name : str ,
183183 input_shape ,
184+ save_model : bool ,
185+ model_file_name : str ,
184186 quantize : bool ,
185187 validate : bool ,
186188 dataset_path : str ,
@@ -198,6 +200,8 @@ def main(
198200 :param suite: The model suite to use (e.g., "timm", "torchvision", "huggingface").
199201 :param model_name: The name of the model to load.
200202 :param input_shape: The input shape for the model.
203+ :param save_model: Whether to save the compiled model as a .pte file.
204+ :param model_file_name: Custom file name to save the exported model.
201205 :param quantize: Whether to quantize the model.
202206 :param validate: Whether to validate the model.
203207 :param dataset_path: Path to the dataset for calibration/validation.
@@ -264,10 +268,12 @@ def main(
264268 )
265269
266270 # Serialize and save it to a file
267- model_file_name = f"{ model_name } _{ 'int8' if quantize else 'fp32' } .pte"
268- with open (model_file_name , "wb" ) as file :
269- exec_prog .write_to_file (file )
270- print (f"Model exported and saved as { model_file_name } on { device } ." )
271+ if save_model :
272+ if not model_file_name :
273+ model_file_name = f"{ model_name } _{ 'int8' if quantize else 'fp32' } .pte"
274+ with open (model_file_name , "wb" ) as file :
275+ exec_prog .write_to_file (file )
276+ print (f"Model exported and saved as { model_file_name } on { device } ." )
271277
272278 if validate :
273279 if suite == "huggingface" :
@@ -315,6 +321,14 @@ def main(
315321 help = "Batch size for the validation. Default batch_size == 1."
316322 " The dataset length must be evenly divisible by the batch size." ,
317323 )
324+ parser .add_argument (
325+ "--export" , action = "store_true" , help = "Export the compiled model as .pte file."
326+ )
327+ parser .add_argument (
328+ "--model_file_name" ,
329+ type = str ,
330+ help = "Custom file name to save the exported model." ,
331+ )
318332 parser .add_argument (
319333 "--quantize" , action = "store_true" , help = "Enable model quantization."
320334 )
@@ -367,6 +381,8 @@ def main(
367381 args .suite ,
368382 args .model ,
369383 args .input_shape ,
384+ args .export ,
385+ args .model_file_name ,
370386 args .quantize ,
371387 args .validate ,
372388 args .dataset ,
0 commit comments