Skip to content

Commit 9104dc8

Browse files
Add hack for legacy keras (#1486)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 98e6143 commit 9104dc8

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

tf2onnx/convert.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,20 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
309309

310310
# let tensorflow do the checking if model is a valid model
311311
function = _saving_utils.trace_model_call(model, input_signature)
312-
concrete_func = function.get_concrete_function()
312+
try:
313+
concrete_func = function.get_concrete_function()
314+
except TypeError as e:
315+
# Legacy keras models don't accept the training arg tf provides so we hack around it
316+
if "got an unexpected keyword argument 'training'" not in str(e):
317+
raise e
318+
model_call = model.call
319+
def wrap_call(*args, training=False, **kwargs):
320+
return model_call(*args, **kwargs)
321+
model.call = wrap_call
322+
function = _saving_utils.trace_model_call(model, input_signature)
323+
concrete_func = function.get_concrete_function()
324+
# Put it back
325+
model.call = model_call
313326

314327
# These inputs will be removed during freezing (includes resources, etc.)
315328
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access

0 commit comments

Comments
 (0)