@@ -105,7 +105,7 @@ def load_calibration_dataset(
105105
106106def infer_model (
107107 exec_prog : EdgeProgramManager ,
108- input_shape ,
108+ inputs ,
109109 num_iter : int ,
110110 warmup_iter : int ,
111111 input_path : str ,
@@ -115,7 +115,7 @@ def infer_model(
115115 Executes inference and reports the average timing.
116116
117117 :param exec_prog: EdgeProgramManager of the lowered model
118- :param input_shape : The input shape for the model.
118+ :param inputs : The inputs for the model.
119119 :param num_iter: The number of iterations to execute inference for timing.
120120 :param warmup_iter: The number of iterations to execute inference for warmup before timing.
121121 :param input_path: Path to the input tensor file to read the input for inference.
@@ -128,8 +128,6 @@ def infer_model(
128128 # 2: Initialize inputs
129129 if input_path :
130130 inputs = (torch .load (input_path , weights_only = False ),)
131- else :
132- inputs = (torch .randn (input_shape ),)
133131
134132 # 3: Execute warmup
135133 for _i in range (warmup_iter ):
@@ -232,7 +230,14 @@ def main( # noqa: C901
232230 msg = "Input shape must be a list or tuple."
233231 raise ValueError (msg )
234232 # Provide input
235- example_args = (torch .randn (* input_shape ),)
233+ if suite == "huggingface" :
234+ if hasattr (model , 'config' ) and hasattr (model .config , 'vocab_size' ):
235+ vocab_size = model .config .vocab_size
236+ else :
237+ vocab_size = 30522
238+ example_args = (torch .randint (0 , vocab_size , input_shape , dtype = torch .int64 ), )
239+ else :
240+ example_args = (torch .randn (* input_shape ),)
236241
237242 # Export the model to the aten dialect
238243 aten_dialect : ExportedProgram = export (model , example_args )
@@ -301,7 +306,7 @@ def transform_fn(x):
301306 if infer :
302307 print ("Start inference of the model:" )
303308 avg_time = infer_model (
304- exec_prog , input_shape , num_iter , warmup_iter , input_path , output_path
309+ exec_prog , example_args , num_iter , warmup_iter , input_path , output_path
305310 )
306311 print (f"Average inference time: { avg_time } " )
307312
0 commit comments