Skip to content

Commit d963a72

Browse files
committed
Get model dtype from weight dtype
1 parent 5ac5e5d commit d963a72

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

optimum/exporters/openvino/convert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
_get_dynamic_shapes_info,
7171
_normalize_dummy_inputs,
7272
_get_open_clip_submodels_fn_and_export_configs,
73+
get_model_dtype,
7374
allow_skip_tracing_check,
7475
clear_class_registry,
7576
remove_none_from_dummy_inputs,
@@ -558,7 +559,7 @@ def patched_forward(*args, **kwargs):
558559
# patch_everywhere breaks torch.ops namespace
559560
del torch.ops._prepare_4d_causal_attention_mask_for_sdpa
560561
dynamic_shapes = _get_dynamic_shapes_info(model, config, dummy_inputs)
561-
_export_kwargs = {"args": tuple(), "kwargs": _normalize_dummy_inputs(dummy_inputs, model.dtype)}
562+
_export_kwargs = {"args": tuple(), "kwargs": _normalize_dummy_inputs(dummy_inputs, get_model_dtype(model))}
562563
_export_kwargs["dynamic_shapes"] = dynamic_shapes
563564

564565
try:

optimum/exporters/openvino/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ def _normalize_dummy_inputs(dummy_inputs: Dict[str, Any], dtype: Any) -> Dict[st
181181
return new_dummy
182182

183183

184+
def get_model_dtype(model):
185+
for param in model.parameters():
186+
return param.dtype
187+
return getattr(model, "dtype", torch.float32)
188+
189+
184190
def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
185191
"""
186192
Removes None values from the dictionary.

0 commit comments

Comments
 (0)