
Commit 82866db

Merge pull request #35 from ynimmaga/surya/fix_hf_models
Fix input for hf models
2 parents 69c1902 + ab0cb88 commit 82866db

File tree

2 files changed (+25, -20 lines):
  backends/openvino/runtime/OpenvinoBackend.cpp
  examples/openvino/aot_optimize_and_infer.py

backends/openvino/runtime/OpenvinoBackend.cpp

Lines changed: 4 additions & 0 deletions
@@ -170,6 +170,10 @@ ov::element::Type OpenvinoBackend::convert_to_openvino_type(
       return ov::element::i32;
     case exa::ScalarType::Char:
       return ov::element::i8;
+    case exa::ScalarType::Long:
+      return ov::element::i64;
+    case exa::ScalarType::Bool:
+      return ov::element::boolean;
     default:
       throw std::runtime_error("Unsupported scalar type");
   }
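
HuggingFace models feed int64 token IDs (and often bool attention masks) into the backend, which is why Long and Bool now need OpenVINO equivalents. A minimal Python sketch of the translation the C++ switch performs (the torch-dtype pairing is an assumption for illustration; the authoritative mapping is convert_to_openvino_type above):

import torch

# Assumed torch equivalents of the exa::ScalarType cases in the diff.
TORCH_TO_OV_ELEMENT = {
    torch.int32: "i32",     # exa::ScalarType::Int
    torch.int8: "i8",       # exa::ScalarType::Char
    torch.int64: "i64",     # exa::ScalarType::Long (added by this commit)
    torch.bool: "boolean",  # exa::ScalarType::Bool (added by this commit)
}

def to_ov_element(dtype: torch.dtype) -> str:
    try:
        return TORCH_TO_OV_ELEMENT[dtype]
    except KeyError:
        raise RuntimeError("Unsupported scalar type") from None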

examples/openvino/aot_optimize_and_infer.py

Lines changed: 21 additions & 20 deletions
@@ -105,49 +105,41 @@ def load_calibration_dataset(
 
 def infer_model(
     exec_prog: EdgeProgramManager,
-    input_shape,
+    inputs,
     num_iter: int,
     warmup_iter: int,
-    input_path: str,
     output_path: str,
 ) -> float:
     """
     Executes inference and reports the average timing.
 
     :param exec_prog: EdgeProgramManager of the lowered model
-    :param input_shape: The input shape for the model.
+    :param inputs: The inputs for the model.
     :param num_iter: The number of iterations to execute inference for timing.
     :param warmup_iter: The number of iterations to execute inference for warmup before timing.
-    :param input_path: Path to the input tensor file to read the input for inference.
     :param output_path: Path to the output tensor file to save the output of inference.
     :return: The average inference timing.
     """
-    # 1: Load model from buffer
+    # Load model from buffer
     executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
 
-    # 2: Initialize inputs
-    if input_path:
-        inputs = (torch.load(input_path, weights_only=False),)
-    else:
-        inputs = (torch.randn(input_shape),)
-
-    # 3: Execute warmup
+    # Execute warmup
     for _i in range(warmup_iter):
         out = executorch_module.run_method("forward", inputs)
 
-    # 4: Execute inference and measure timing
+    # Execute inference and measure timing
     time_total = 0.0
     for _i in range(num_iter):
         time_start = time.time()
         out = executorch_module.run_method("forward", inputs)
         time_end = time.time()
         time_total += time_end - time_start
 
-    # 5: Save output tensor as raw tensor file
+    # Save output tensor as raw tensor file
    if output_path:
         torch.save(out, output_path)
 
-    # 6: Return average inference timing
+    # Return average inference timing
     return time_total / float(num_iter)
 
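With input loading moved out of infer_model, the caller now builds the input tuple once and passes in the same tuple the model was exported with. A minimal usage sketch under that assumption (the shape and iteration counts are illustrative):

import torch

# example_args is the tuple the model was exported with; infer_model
# simply replays it for warmup and the timed runs.
example_args = (torch.randn(1, 3, 224, 224),)  # illustrative CNN-style input

avg_time = infer_model(
    exec_prog,       # EdgeProgramManager holding the lowered program
    example_args,    # pre-built input tuple (replaces input_shape/input_path)
    num_iter=100,    # timed iterations
    warmup_iter=10,  # untimed warmup iterations
    output_path="",  # falsy path skips saving the output tensor
)
print(f"Average inference time: {avg_time}")
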
@@ -161,10 +153,10 @@ def validate_model(
     :param calibration_dataset: A DataLoader containing calibration data.
     :return: The accuracy score of the model.
     """
-    # 1: Load model from buffer
+    # Load model from buffer
     executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
 
-    # 2: Iterate over the dataset and run the executor
+    # Iterate over the dataset and run the executor
     predictions = []
     targets = []
     for _idx, data in enumerate(calibration_dataset):
@@ -173,7 +165,7 @@ def validate_model(
         out = executorch_module.run_method("forward", (feature,))
         predictions.extend(torch.stack(out).reshape(-1, 1000).argmax(-1))
 
-    # 1: Check accuracy
+    # Check accuracy
     return accuracy_score(predictions, targets)
 
@@ -232,7 +224,16 @@ def main( # noqa: C901
         msg = "Input shape must be a list or tuple."
         raise ValueError(msg)
     # Provide input
-    example_args = (torch.randn(*input_shape),)
+    if input_path:
+        example_args = (torch.load(input_path, weights_only=False),)
+    elif suite == "huggingface":
+        if hasattr(model, "config") and hasattr(model.config, "vocab_size"):
+            vocab_size = model.config.vocab_size
+        else:
+            vocab_size = 30522
+        example_args = (torch.randint(0, vocab_size, input_shape, dtype=torch.int64),)
+    else:
+        example_args = (torch.randn(*input_shape),)
 
     # Export the model to the aten dialect
     aten_dialect: ExportedProgram = export(model, example_args)
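
For the huggingface suite this branch builds integer token IDs rather than random floats, since an embedding layer cannot index float inputs; the 30522 fallback is the bert-base-uncased vocabulary size. A minimal sketch of the constructed input (the model name and shape are illustrative assumptions):

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
input_shape = (1, 128)  # (batch, sequence length) -- illustrative

# Token IDs must be int64 and lie in [0, vocab_size).
vocab_size = getattr(model.config, "vocab_size", 30522)
example_args = (torch.randint(0, vocab_size, input_shape, dtype=torch.int64),)

print(example_args[0].dtype, example_args[0].shape)
# torch.int64 torch.Size([1, 128])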
@@ -301,7 +302,7 @@ def transform_fn(x):
     if infer:
         print("Start inference of the model:")
         avg_time = infer_model(
-            exec_prog, input_shape, num_iter, warmup_iter, input_path, output_path
+            exec_prog, example_args, num_iter, warmup_iter, output_path
         )
         print(f"Average inference time: {avg_time}")
 