@@ -108,7 +108,6 @@ def infer_model(
108108 inputs ,
109109 num_iter : int ,
110110 warmup_iter : int ,
111- input_path : str ,
112111 output_path : str ,
113112) -> float :
114113 """
@@ -118,34 +117,29 @@ def infer_model(
118117 :param inputs: The inputs for the model.
119118 :param num_iter: The number of iterations to execute inference for timing.
120119 :param warmup_iter: The number of iterations to execute inference for warmup before timing.
121- :param input_path: Path to the input tensor file to read the input for inference.
122120 :param output_path: Path to the output tensor file to save the output of inference..
123121 :return: The average inference timing.
124122 """
125- # 1: Load model from buffer
123+ # Load model from buffer
126124 executorch_module = _load_for_executorch_from_buffer (exec_prog .buffer )
127125
128- # 2: Initialize inputs
129- if input_path :
130- inputs = (torch .load (input_path , weights_only = False ),)
131-
132- # 3: Execute warmup
126+ # Execute warmup
133127 for _i in range (warmup_iter ):
134128 out = executorch_module .run_method ("forward" , inputs )
135129
136- # 4: Execute inference and measure timing
130+ # Execute inference and measure timing
137131 time_total = 0.0
138132 for _i in range (num_iter ):
139133 time_start = time .time ()
140134 out = executorch_module .run_method ("forward" , inputs )
141135 time_end = time .time ()
142136 time_total += time_end - time_start
143137
144- # 5: Save output tensor as raw tensor file
138+ # Save output tensor as raw tensor file
145139 if output_path :
146140 torch .save (out , output_path )
147141
148- # 6: Return average inference timing
142+ # Return average inference timing
149143 return time_total / float (num_iter )
150144
151145
@@ -159,10 +153,10 @@ def validate_model(
159153 :param calibration_dataset: A DataLoader containing calibration data.
160154 :return: The accuracy score of the model.
161155 """
162- # 1: Load model from buffer
156+ # Load model from buffer
163157 executorch_module = _load_for_executorch_from_buffer (exec_prog .buffer )
164158
165- # 2: Iterate over the dataset and run the executor
159+ # Iterate over the dataset and run the executor
166160 predictions = []
167161 targets = []
168162 for _idx , data in enumerate (calibration_dataset ):
@@ -171,7 +165,7 @@ def validate_model(
171165 out = executorch_module .run_method ("forward" , (feature ,))
172166 predictions .extend (torch .stack (out ).reshape (- 1 , 1000 ).argmax (- 1 ))
173167
174- # 1: Check accuracy
168+ # Check accuracy
175169 return accuracy_score (predictions , targets )
176170
177171
@@ -230,12 +224,14 @@ def main( # noqa: C901
230224 msg = "Input shape must be a list or tuple."
231225 raise ValueError (msg )
232226 # Provide input
233- if suite == "huggingface" :
234- if hasattr (model , 'config' ) and hasattr (model .config , 'vocab_size' ):
227+ if input_path :
228+ example_args = (torch .load (input_path , weights_only = False ),)
229+ elif suite == "huggingface" :
230+ if hasattr (model , "config" ) and hasattr (model .config , "vocab_size" ):
235231 vocab_size = model .config .vocab_size
236232 else :
237233 vocab_size = 30522
238- example_args = (torch .randint (0 , vocab_size , input_shape , dtype = torch .int64 ), )
234+ example_args = (torch .randint (0 , vocab_size , input_shape , dtype = torch .int64 ),)
239235 else :
240236 example_args = (torch .randn (* input_shape ),)
241237
@@ -306,7 +302,7 @@ def transform_fn(x):
306302 if infer :
307303 print ("Start inference of the model:" )
308304 avg_time = infer_model (
309- exec_prog , example_args , num_iter , warmup_iter , input_path , output_path
305+ exec_prog , example_args , num_iter , warmup_iter , output_path
310306 )
311307 print (f"Average inference time: { avg_time } " )
312308