@@ -114,6 +114,7 @@ def parameterize_inputs(inputs, prefix=""):
114
114
inputs = tree .map_structure (make_tf_tensor_spec , input_signature )
115
115
decorated_fn = get_concrete_fn (model , inputs , ** kwargs )
116
116
ov_model = ov .convert_model (decorated_fn )
117
+ set_names (ov_model , inputs )
117
118
elif backend .backend () == "torch" :
118
119
import torch
119
120
@@ -128,6 +129,7 @@ def parameterize_inputs(inputs, prefix=""):
128
129
warnings .filterwarnings ("ignore" , category = torch .jit .TracerWarning )
129
130
traced = torch .jit .trace (model , sample_inputs )
130
131
ov_model = ov .convert_model (traced )
132
+ set_names (ov_model , sample_inputs )
131
133
else :
132
134
raise NotImplementedError (
133
135
"`export_openvino` is only compatible with OpenVINO, "
@@ -140,6 +142,30 @@ def parameterize_inputs(inputs, prefix=""):
140
142
io_utils .print_msg (f"Saved OpenVINO IR at '{ filepath } '." )
141
143
142
144
145
+ def collect_names (structure ):
146
+ if isinstance (structure , dict ):
147
+ for k , v in structure .items ():
148
+ if isinstance (v , (dict , list , tuple )):
149
+ yield from collect_names (v )
150
+ else :
151
+ yield k
152
+ elif isinstance (structure , (list , tuple )):
153
+ for v in structure :
154
+ yield from collect_names (v )
155
+ else :
156
+ if hasattr (structure , "name" ) and structure .name :
157
+ yield structure .name
158
+ else :
159
+ yield "input"
160
+
161
+
162
+ def set_names (model , inputs ):
163
+ names = list (collect_names (inputs ))
164
+ for ov_input , name in zip (model .inputs , names ):
165
+ ov_input .get_node ().set_friendly_name (name )
166
+ ov_input .tensor .set_names ({name })
167
+
168
+
143
169
def _check_jax_kwargs (kwargs ):
144
170
kwargs = kwargs .copy ()
145
171
if "is_static" not in kwargs :
0 commit comments