Skip to content

Commit 4085046

Browse files
[OpenVINO backend] fix openvino model exported names to match keras names (#21526)
* [OpenVINO backend] fix openvino model exported names to match keras names * Update keras/src/export/openvino.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update keras/src/export/openvino.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 6bc6203 commit 4085046

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

keras/src/export/openvino.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def parameterize_inputs(inputs, prefix=""):
114114
inputs = tree.map_structure(make_tf_tensor_spec, input_signature)
115115
decorated_fn = get_concrete_fn(model, inputs, **kwargs)
116116
ov_model = ov.convert_model(decorated_fn)
117+
set_names(ov_model, inputs)
117118
elif backend.backend() == "torch":
118119
import torch
119120

@@ -128,6 +129,7 @@ def parameterize_inputs(inputs, prefix=""):
128129
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
129130
traced = torch.jit.trace(model, sample_inputs)
130131
ov_model = ov.convert_model(traced)
132+
set_names(ov_model, sample_inputs)
131133
else:
132134
raise NotImplementedError(
133135
"`export_openvino` is only compatible with OpenVINO, "
@@ -140,6 +142,30 @@ def parameterize_inputs(inputs, prefix=""):
140142
io_utils.print_msg(f"Saved OpenVINO IR at '{filepath}'.")
141143

142144

145+
def collect_names(structure):
146+
if isinstance(structure, dict):
147+
for k, v in structure.items():
148+
if isinstance(v, (dict, list, tuple)):
149+
yield from collect_names(v)
150+
else:
151+
yield k
152+
elif isinstance(structure, (list, tuple)):
153+
for v in structure:
154+
yield from collect_names(v)
155+
else:
156+
if hasattr(structure, "name") and structure.name:
157+
yield structure.name
158+
else:
159+
yield "input"
160+
161+
162+
def set_names(model, inputs):
163+
names = list(collect_names(inputs))
164+
for ov_input, name in zip(model.inputs, names):
165+
ov_input.get_node().set_friendly_name(name)
166+
ov_input.tensor.set_names({name})
167+
168+
143169
def _check_jax_kwargs(kwargs):
144170
kwargs = kwargs.copy()
145171
if "is_static" not in kwargs:

0 commit comments

Comments
 (0)