@@ -273,8 +273,7 @@ def main():
273
273
274
274
def tensor_names_from_structed (concrete_func , input_names , output_names ):
275
275
tensors_to_rename = {}
276
- args , kwargs = concrete_func .structured_input_signature
277
- structured_inputs = [t .name for t in args if isinstance (t , tf .TensorSpec )] + sorted (kwargs .keys ())
276
+ structured_inputs = [t .name for t in tf .nest .flatten (concrete_func .structured_input_signature )]
278
277
tensors_to_rename .update (zip (input_names , structured_inputs ))
279
278
if isinstance (concrete_func .structured_outputs , dict ):
280
279
for k , v in concrete_func .structured_outputs .items ():
@@ -306,17 +305,17 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
306
305
if LooseVersion (tf .__version__ ) < "2.0" :
307
306
raise NotImplementedError ("from_keras requires tf-2.0 or newer" )
308
307
309
- if not input_signature :
310
- raise ValueError ("from_keras requires input_signature" )
311
-
312
308
from tensorflow .python .keras .saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel
313
309
314
310
# let tensorflow do the checking if model is a valid model
315
311
function = _saving_utils .trace_model_call (model , input_signature )
316
- concrete_func = function .get_concrete_function (* input_signature )
312
+ concrete_func = function .get_concrete_function ()
317
313
314
+ # These inputs will be removed during freezing (includes resources, etc.)
315
+ graph_captures = concrete_func .graph ._captures # pylint: disable=protected-access
316
+ captured_inputs = [t_name .name for t_val , t_name in graph_captures .values ()]
318
317
input_names = [input_tensor .name for input_tensor in concrete_func .inputs
319
- if input_tensor .dtype != tf . dtypes . resource ]
318
+ if input_tensor .name not in captured_inputs ]
320
319
output_names = [output_tensor .name for output_tensor in concrete_func .outputs
321
320
if output_tensor .dtype != tf .dtypes .resource ]
322
321
0 commit comments