@@ -105,49 +105,41 @@ def load_calibration_dataset(
105105
def infer_model(
    exec_prog: EdgeProgramManager,
    inputs,
    num_iter: int,
    warmup_iter: int,
    output_path: str,
) -> float:
    """
    Executes inference and reports the average timing.

    :param exec_prog: EdgeProgramManager of the lowered model.
    :param inputs: The inputs tuple passed to the model's "forward" method.
    :param num_iter: The number of iterations to execute inference for timing.
    :param warmup_iter: The number of iterations to execute inference for warmup
        before timing.
    :param output_path: Path to the output tensor file to save the output of
        inference. If falsy, the output is not saved.
    :return: The average inference timing in seconds over ``num_iter`` runs.
    :raises ValueError: If ``num_iter`` is less than 1.
    """
    if num_iter < 1:
        # Fail early with a clear message instead of a ZeroDivisionError at
        # the final averaging step (and a possible UnboundLocalError on `out`).
        msg = "num_iter must be >= 1."
        raise ValueError(msg)

    # Load model from buffer
    executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)

    # Execute warmup; results are intentionally discarded.
    for _ in range(warmup_iter):
        out = executorch_module.run_method("forward", inputs)

    # Execute inference and measure timing.
    # time.perf_counter is the documented monotonic, high-resolution clock for
    # measuring short durations; time.time is wall-clock and may step backwards.
    time_total = 0.0
    for _ in range(num_iter):
        time_start = time.perf_counter()
        out = executorch_module.run_method("forward", inputs)
        time_end = time.perf_counter()
        time_total += time_end - time_start

    # Save the last output tensor as a raw tensor file
    if output_path:
        torch.save(out, output_path)

    # Return average inference timing
    return time_total / float(num_iter)
152144
153145
@@ -161,10 +153,10 @@ def validate_model(
161153 :param calibration_dataset: A DataLoader containing calibration data.
162154 :return: The accuracy score of the model.
163155 """
164- # 1: Load model from buffer
156+ # Load model from buffer
165157 executorch_module = _load_for_executorch_from_buffer (exec_prog .buffer )
166158
167- # 2: Iterate over the dataset and run the executor
159+ # Iterate over the dataset and run the executor
168160 predictions = []
169161 targets = []
170162 for _idx , data in enumerate (calibration_dataset ):
@@ -173,7 +165,7 @@ def validate_model(
173165 out = executorch_module .run_method ("forward" , (feature ,))
174166 predictions .extend (torch .stack (out ).reshape (- 1 , 1000 ).argmax (- 1 ))
175167
176- # 1: Check accuracy
168+ # Check accuracy
177169 return accuracy_score (predictions , targets )
178170
179171
@@ -232,7 +224,16 @@ def main( # noqa: C901
232224 msg = "Input shape must be a list or tuple."
233225 raise ValueError (msg )
234226 # Provide input
235- example_args = (torch .randn (* input_shape ),)
227+ if input_path :
228+ example_args = (torch .load (input_path , weights_only = False ),)
229+ elif suite == "huggingface" :
230+ if hasattr (model , "config" ) and hasattr (model .config , "vocab_size" ):
231+ vocab_size = model .config .vocab_size
232+ else :
233+ vocab_size = 30522
234+ example_args = (torch .randint (0 , vocab_size , input_shape , dtype = torch .int64 ),)
235+ else :
236+ example_args = (torch .randn (* input_shape ),)
236237
237238 # Export the model to the aten dialect
238239 aten_dialect : ExportedProgram = export (model , example_args )
@@ -301,7 +302,7 @@ def transform_fn(x):
301302 if infer :
302303 print ("Start inference of the model:" )
303304 avg_time = infer_model (
304- exec_prog , input_shape , num_iter , warmup_iter , input_path , output_path
305+ exec_prog , example_args , num_iter , warmup_iter , output_path
305306 )
306307 print (f"Average inference time: { avg_time } " )
307308
0 commit comments