Skip to content

Commit eac6b9f

Browse files
Improve keras API (#1477)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent ef8f7f3 commit eac6b9f

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tf2onnx/convert.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,7 @@ def main():
273273

274274
def tensor_names_from_structed(concrete_func, input_names, output_names):
275275
tensors_to_rename = {}
276-
args, kwargs = concrete_func.structured_input_signature
277-
structured_inputs = [t.name for t in args if isinstance(t, tf.TensorSpec)] + sorted(kwargs.keys())
276+
structured_inputs = [t.name for t in tf.nest.flatten(concrete_func.structured_input_signature)]
278277
tensors_to_rename.update(zip(input_names, structured_inputs))
279278
if isinstance(concrete_func.structured_outputs, dict):
280279
for k, v in concrete_func.structured_outputs.items():
@@ -306,17 +305,17 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
306305
if LooseVersion(tf.__version__) < "2.0":
307306
raise NotImplementedError("from_keras requires tf-2.0 or newer")
308307

309-
if not input_signature:
310-
raise ValueError("from_keras requires input_signature")
311-
312308
from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel
313309

314310
# let tensorflow do the checking if model is a valid model
315311
function = _saving_utils.trace_model_call(model, input_signature)
316-
concrete_func = function.get_concrete_function(*input_signature)
312+
concrete_func = function.get_concrete_function()
317313

314+
# These inputs will be removed during freezing (includes resources, etc.)
315+
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
316+
captured_inputs = [t_name.name for t_val, t_name in graph_captures.values()]
318317
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
319-
if input_tensor.dtype != tf.dtypes.resource]
318+
if input_tensor.name not in captured_inputs]
320319
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
321320
if output_tensor.dtype != tf.dtypes.resource]
322321

0 commit comments

Comments
 (0)