Skip to content

Commit ab0cb88

Browse files
committed
Addressed PR comments
1 parent ca3ec9c commit ab0cb88

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

examples/openvino/aot_optimize_and_infer.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def infer_model(
108108
inputs,
109109
num_iter: int,
110110
warmup_iter: int,
111-
input_path: str,
112111
output_path: str,
113112
) -> float:
114113
"""
@@ -118,34 +117,29 @@ def infer_model(
118117
:param inputs: The inputs for the model.
119118
:param num_iter: The number of iterations to execute inference for timing.
120119
:param warmup_iter: The number of iterations to execute inference for warmup before timing.
121-
:param input_path: Path to the input tensor file to read the input for inference.
122120
:param output_path: Path to the output tensor file to save the output of inference.
123121
:return: The average inference timing.
124122
"""
125-
# 1: Load model from buffer
123+
# Load model from buffer
126124
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
127125

128-
# 2: Initialize inputs
129-
if input_path:
130-
inputs = (torch.load(input_path, weights_only=False),)
131-
132-
# 3: Execute warmup
126+
# Execute warmup
133127
for _i in range(warmup_iter):
134128
out = executorch_module.run_method("forward", inputs)
135129

136-
# 4: Execute inference and measure timing
130+
# Execute inference and measure timing
137131
time_total = 0.0
138132
for _i in range(num_iter):
139133
time_start = time.time()
140134
out = executorch_module.run_method("forward", inputs)
141135
time_end = time.time()
142136
time_total += time_end - time_start
143137

144-
# 5: Save output tensor as raw tensor file
138+
# Save output tensor as raw tensor file
145139
if output_path:
146140
torch.save(out, output_path)
147141

148-
# 6: Return average inference timing
142+
# Return average inference timing
149143
return time_total / float(num_iter)
150144

151145

@@ -159,10 +153,10 @@ def validate_model(
159153
:param calibration_dataset: A DataLoader containing calibration data.
160154
:return: The accuracy score of the model.
161155
"""
162-
# 1: Load model from buffer
156+
# Load model from buffer
163157
executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
164158

165-
# 2: Iterate over the dataset and run the executor
159+
# Iterate over the dataset and run the executor
166160
predictions = []
167161
targets = []
168162
for _idx, data in enumerate(calibration_dataset):
@@ -171,7 +165,7 @@ def validate_model(
171165
out = executorch_module.run_method("forward", (feature,))
172166
predictions.extend(torch.stack(out).reshape(-1, 1000).argmax(-1))
173167

174-
# 1: Check accuracy
168+
# Check accuracy
175169
return accuracy_score(predictions, targets)
176170

177171

@@ -230,12 +224,14 @@ def main( # noqa: C901
230224
msg = "Input shape must be a list or tuple."
231225
raise ValueError(msg)
232226
# Provide input
233-
if suite == "huggingface":
234-
if hasattr(model, 'config') and hasattr(model.config, 'vocab_size'):
227+
if input_path:
228+
example_args = (torch.load(input_path, weights_only=False),)
229+
elif suite == "huggingface":
230+
if hasattr(model, "config") and hasattr(model.config, "vocab_size"):
235231
vocab_size = model.config.vocab_size
236232
else:
237233
vocab_size = 30522
238-
example_args = (torch.randint(0, vocab_size, input_shape, dtype=torch.int64), )
234+
example_args = (torch.randint(0, vocab_size, input_shape, dtype=torch.int64),)
239235
else:
240236
example_args = (torch.randn(*input_shape),)
241237

@@ -306,7 +302,7 @@ def transform_fn(x):
306302
if infer:
307303
print("Start inference of the model:")
308304
avg_time = infer_model(
309-
exec_prog, example_args, num_iter, warmup_iter, input_path, output_path
305+
exec_prog, example_args, num_iter, warmup_iter, output_path
310306
)
311307
print(f"Average inference time: {avg_time}")
312308

0 commit comments

Comments
 (0)