55# LICENSE file in the root directory of this source tree.
66
77import argparse
8+ import time
89
910import executorch
1011
@@ -102,6 +103,54 @@ def load_calibration_dataset(
102103 return calibration_dataset
103104
104105
def infer_model(
    exec_prog: EdgeProgramManager,
    input_shape,
    num_iter: int,
    warmup_iter: int,
    input_path: str,
    output_path: str,
) -> float:
    """
    Execute inference on the lowered model and report the average timing.

    :param exec_prog: EdgeProgramManager of the lowered model.
    :param input_shape: The input shape used to generate a random input
        tensor when no input file is given.
    :param num_iter: The number of timed inference iterations (must be >= 1).
    :param warmup_iter: The number of untimed warmup iterations run first.
    :param input_path: Optional path to a saved input tensor to read the
        input for inference; when falsy, a random tensor is used instead.
    :param output_path: Optional path where the last inference output is
        saved as a tensor file.
    :return: The average wall-clock time per inference iteration, in seconds.
    :raises ValueError: If ``num_iter`` is less than 1.
    """
    # Guard against ZeroDivisionError in the average below and an undefined
    # `out` when saving the output with no iterations executed.
    if num_iter < 1:
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")

    # 1: Load model from buffer
    executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)

    # 2: Initialize inputs
    if input_path:
        # NOTE(review): torch.load with weights_only=False can execute
        # arbitrary code embedded in the file — only load trusted inputs.
        inputs = (torch.load(input_path, weights_only=False),)
    else:
        inputs = (torch.randn(input_shape),)

    # 3: Execute warmup (untimed)
    for _ in range(warmup_iter):
        executorch_module.run_method("forward", inputs)

    # 4: Execute inference and measure timing. perf_counter() is monotonic
    # and higher-resolution than time.time(), which can jump backwards on
    # wall-clock adjustments, so it is the right API for interval timing.
    time_total = 0.0
    for _ in range(num_iter):
        time_start = time.perf_counter()
        out = executorch_module.run_method("forward", inputs)
        time_total += time.perf_counter() - time_start

    # 5: Save the last output tensor as a raw tensor file
    if output_path:
        torch.save(out, output_path)

    # 6: Return average inference timing
    return time_total / num_iter
152+
153+
105154def validate_model (
106155 exec_prog : EdgeProgramManager , calibration_dataset : torch .utils .data .DataLoader
107156) -> float :
@@ -137,6 +186,11 @@ def main(
137186 dataset_path : str ,
138187 device : str ,
139188 batch_size : int ,
189+ infer : bool ,
190+ num_iter : int ,
191+ warmup_iter : int ,
192+ input_path : str ,
193+ output_path : str ,
140194):
141195 """
142196 Main function to load, quantize, and validate a model.
@@ -149,6 +203,12 @@ def main(
149203 :param dataset_path: Path to the dataset for calibration/validation.
150204 :param device: The device to run the model on (e.g., "cpu", "gpu").
151205 :param batch_size: Batch size for dataset loading.
206+ :param infer: Whether to execute inference and report timing.
207+ :param num_iter: The number of iterations to execute inference for timing.
208+ :param warmup_iter: The number of iterations to execute inference for warmup before timing.
209+ :param input_path: Path to the input tensor file to read the input for inference.
150204 :param output_path: Path to the output tensor file to save the output of inference.
211+
152212 """
153213
154214 # Load the selected model
@@ -222,6 +282,13 @@ def main(
222282 acc_top1 = validate_model (exec_prog , calibration_dataset )
223283 print (f"acc@1: { acc_top1 } " )
224284
285+ if infer :
286+ print ("Start inference of the model:" )
287+ avg_time = infer_model (
288+ exec_prog , input_shape , num_iter , warmup_iter , input_path , output_path
289+ )
290+ print (f"Average inference time: { avg_time } " )
291+
225292
226293if __name__ == "__main__" :
227294 # Argument parser for dynamic inputs
@@ -256,6 +323,33 @@ def main(
256323 action = "store_true" ,
257324 help = "Enable model validation. --dataset argument is required for the validation." ,
258325 )
326+ parser .add_argument (
327+ "--infer" ,
328+ action = "store_true" ,
329+ help = "Run inference and report timing." ,
330+ )
331+ parser .add_argument (
332+ "--num_iter" ,
333+ type = int ,
334+ default = 1 ,
335+ help = "The number of iterations to execute inference for timing." ,
336+ )
337+ parser .add_argument (
338+ "--warmup_iter" ,
339+ type = int ,
340+ default = 0 ,
341+ help = "The number of iterations to execute inference for warmup before timing." ,
342+ )
343+ parser .add_argument (
344+ "--input_tensor_path" ,
345+ type = str ,
346+ help = "Path to the input tensor file to read the input for inference." ,
347+ )
348+ parser .add_argument (
349+ "--output_tensor_path" ,
350+ type = str ,
351+ help = "Path to the output tensor file to save the output of inference." ,
352+ )
259353 parser .add_argument ("--dataset" , type = str , help = "Path to the validation dataset." )
260354 parser .add_argument (
261355 "--device" ,
@@ -278,4 +372,9 @@ def main(
278372 args .dataset ,
279373 args .device ,
280374 args .batch_size ,
375+ args .infer ,
376+ args .num_iter ,
377+ args .warmup_iter ,
378+ args .input_tensor_path ,
379+ args .output_tensor_path ,
281380 )
0 commit comments