Skip to content

Commit ca3ec9c

Browse files
committed
Fix inputs for hf models
1 parent 69c1902 commit ca3ec9c

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

backends/openvino/runtime/OpenvinoBackend.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ ov::element::Type OpenvinoBackend::convert_to_openvino_type(
170170
return ov::element::i32;
171171
case exa::ScalarType::Char:
172172
return ov::element::i8;
173+
case exa::ScalarType::Long:
174+
return ov::element::i64;
175+
case exa::ScalarType::Bool:
176+
return ov::element::boolean;
173177
default:
174178
throw std::runtime_error("Unsupported scalar type");
175179
}

examples/openvino/aot_optimize_and_infer.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def load_calibration_dataset(
105105

106106
def infer_model(
107107
exec_prog: EdgeProgramManager,
108-
input_shape,
108+
inputs,
109109
num_iter: int,
110110
warmup_iter: int,
111111
input_path: str,
@@ -115,7 +115,7 @@ def infer_model(
115115
Executes inference and reports the average timing.
116116
117117
:param exec_prog: EdgeProgramManager of the lowered model
118-
:param input_shape: The input shape for the model.
118+
:param inputs: The inputs for the model.
119119
:param num_iter: The number of iterations to execute inference for timing.
120120
:param warmup_iter: The number of iterations to execute inference for warmup before timing.
121121
:param input_path: Path to the input tensor file to read the input for inference.
@@ -128,8 +128,6 @@ def infer_model(
128128
# 2: Initialize inputs
129129
if input_path:
130130
inputs = (torch.load(input_path, weights_only=False),)
131-
else:
132-
inputs = (torch.randn(input_shape),)
133131

134132
# 3: Execute warmup
135133
for _i in range(warmup_iter):
@@ -232,7 +230,14 @@ def main( # noqa: C901
232230
msg = "Input shape must be a list or tuple."
233231
raise ValueError(msg)
234232
# Provide input
235-
example_args = (torch.randn(*input_shape),)
233+
if suite == "huggingface":
234+
if hasattr(model, 'config') and hasattr(model.config, 'vocab_size'):
235+
vocab_size = model.config.vocab_size
236+
else:
237+
vocab_size = 30522
238+
example_args = (torch.randint(0, vocab_size, input_shape, dtype=torch.int64), )
239+
else:
240+
example_args = (torch.randn(*input_shape),)
236241

237242
# Export the model to the aten dialect
238243
aten_dialect: ExportedProgram = export(model, example_args)
@@ -301,7 +306,7 @@ def transform_fn(x):
301306
if infer:
302307
print("Start inference of the model:")
303308
avg_time = infer_model(
304-
exec_prog, input_shape, num_iter, warmup_iter, input_path, output_path
309+
exec_prog, example_args, num_iter, warmup_iter, input_path, output_path
305310
)
306311
print(f"Average inference time: {avg_time}")
307312

0 commit comments

Comments (0)