Skip to content

Commit 8a61c99

Browse files
add pure keras tests from keras2onnx (#1574)
* add pure keras tests from keras2onnx Signed-off-by: Tom Wildenhain <[email protected]> * Disable some failing pure keras tests for old tf versions Signed-off-by: Tom Wildenhain <[email protected]> * polish changes Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c84b7c6 commit 8a61c99

File tree

5 files changed

+138
-14
lines changed

5 files changed

+138
-14
lines changed

ci_build/azure_pipelines/keras2onnx_unit_test.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ jobs:
88
strategy:
99
matrix:
1010

11+
############ TF Keras Unit Tests ############
1112
Python36-tf1.15:
1213
python.version: '3.6'
1314
ONNX_PATH: onnx==1.5.0
@@ -38,6 +39,36 @@ jobs:
3839
TENSORFLOW_PATH: tensorflow-cpu==2.5.0
3940
INSTALL_ORT: pip install onnxruntime==1.8.0
4041

42+
############ Pure Keras Unit Tests ############
43+
# Keras-Py36-tf1.15.0: # Failing, will enable soon.
44+
# python.version: '3.6'
45+
# ONNX_PATH: onnx==1.5.0
46+
# KERAS: keras==2.2.5
47+
# TENSORFLOW_PATH: tensorflow==1.15.0
48+
# INSTALL_ORT: pip install onnxruntime==1.8.0
49+
50+
Keras-Py37-tf1.15.0:
51+
python.version: '3.7'
52+
ONNX_PATH: onnx==1.9.0
53+
KERAS: keras==2.4.3
54+
TENSORFLOW_PATH: tensorflow==1.15.0
55+
INSTALL_ORT: pip install onnxruntime==1.8.0
56+
57+
# UT for keras 2.3 need tensorflow <= 2.0.0
58+
Keras-Py37-tf2.0.0:
59+
python.version: '3.7'
60+
ONNX_PATH: onnx==1.6.0
61+
KERAS: keras==2.3.1
62+
TENSORFLOW_PATH: tensorflow==2.0.0
63+
INSTALL_ORT: pip install onnxruntime==1.8.0
64+
65+
Keras-Py38-tf2.2.0:
66+
python.version: '3.8'
67+
ONNX_PATH: onnx==1.7.0
68+
KERAS: keras==2.4.3
69+
TENSORFLOW_PATH: tensorflow==2.2.0
70+
INSTALL_ORT: pip install onnxruntime==1.8.0
71+
4172
steps:
4273
- script: sudo install -d -m 0777 /home/vsts/.conda/envs
4374
displayName: Fix Conda permissions
@@ -55,6 +86,10 @@ jobs:
5586
pip install h5py==2.9.0
5687
pip install numpy==1.19
5788
pip install $(TENSORFLOW_PATH)
89+
if [[ ! -z $KERAS ]];
90+
then
91+
pip install $(KERAS)
92+
fi
5893
pip install git+https://github.com/microsoft/onnxconverter-common
5994
pip install -r requirements.txt
6095
pip install -r requirements-dev.txt
@@ -66,6 +101,10 @@ jobs:
66101
pip install -e .
67102
python -c "import onnxruntime"
68103
python -c "import onnxconverter_common"
104+
if [[ ! -z $KERAS ]];
105+
then
106+
export TF_KERAS=0
107+
fi
69108
pytest keras2onnx_tests --doctest-modules --junitxml=junit/test-results.xml
70109
displayName: 'pytest'
71110

keras2onnx_tests/test_cgan.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tensorflow as tf
55
import mock_keras2onnx
66
import numpy as np
7-
from mock_keras2onnx.proto import keras, is_tf_keras
7+
from mock_keras2onnx.proto import keras, is_tf_keras, is_tensorflow_older_than
88
from tf2onnx.keras2onnx_api import convert_keras
99
from distutils.version import StrictVersion
1010

@@ -118,6 +118,8 @@ def build_discriminator(self):
118118
@pytest.mark.skipif(mock_keras2onnx.proto.tfcompat.is_tf2 and is_tf_keras, reason="Tensorflow 1.x only tests.")
119119
@pytest.mark.skipif(is_tf_keras and StrictVersion(tf.__version__.split('-')[0]) < StrictVersion("1.14.0"),
120120
reason="Not supported before tensorflow 1.14.0 for tf_keras")
121+
@pytest.mark.skipif(mock_keras2onnx.proto.tfcompat.is_tf2 and is_tensorflow_older_than('2.2'),
122+
reason="Variable freezing fails to replace ResourceGather op")
121123
def test_CGAN(runner):
122124
keras_model = CGAN().combined
123125
batch = 5

keras2onnx_tests/test_layers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
import numpy as np
5-
from onnxconverter_common.onnx_ex import get_maximum_opset_supported
5+
from tf2onnx.keras2onnx_api import get_maximum_opset_supported
66
from mock_keras2onnx.proto.tfcompat import is_tf2, tensorflow as tf
77
from mock_keras2onnx.proto import (keras, is_tf_keras,
88
is_tensorflow_older_than, is_tensorflow_later_than,
@@ -1633,6 +1633,8 @@ def test_padding(misc_conv_runner):
16331633
misc_conv_runner(layer, ishape)
16341634

16351635

1636+
@pytest.mark.skipif(is_tf2 and is_tensorflow_older_than('2.2'),
1637+
reason="Variable freezing fails to replace ResourceGather op")
16361638
def test_embedding(runner):
16371639
model = keras.Sequential()
16381640
model.add(Embedding(1000, 64, input_length=10))
@@ -1853,6 +1855,8 @@ def test_GRU(runner):
18531855
assert runner(onnx_model.graph.name, onnx_model, [data, init_state_onnx], expected)
18541856

18551857

1858+
@pytest.mark.skipif(not is_tf_keras and is_tf2 and is_tensorflow_older_than('2.2'),
1859+
reason="Fails due to some reason involving bad graph captures. Works in new versions and tf_keras")
18561860
def test_GRU_2(runner):
18571861
model = keras.Sequential(name='TestGRU')
18581862
model.add(keras.layers.GRU(400, reset_after=True, input_shape=(1, 257)))
@@ -2109,6 +2113,8 @@ def test_bidirectional_with_initial_states(runner, rnn_class):
21092113
@pytest.mark.skipif(get_maximum_opset_supported() < 5,
21102114
reason="None seq_length Bidirectional LSTM is not supported before opset 5.")
21112115
@pytest.mark.parametrize("rnn_class", RNN_CLASSES)
2116+
@pytest.mark.skipif(is_tf2 and is_tensorflow_older_than('2.2'),
2117+
reason="Variable freezing fails to replace GatherResource op")
21122118
def test_bidirectional_seqlen_none(runner, rnn_class):
21132119
model = Sequential()
21142120
model.add(Embedding(39, 128))
@@ -2199,6 +2205,8 @@ def test_separable_convolution(runner):
21992205
assert runner('separable_convolution_2', onnx_model, x, expected)
22002206

22012207

2208+
@pytest.mark.skipif(is_tf2 and is_tensorflow_older_than('2.2'),
2209+
reason="Variable freezing fails to replace GatherResource op")
22022210
def test_shared_embed(runner):
22032211
max_cont_length = 5
22042212
max_ques_length = 7

tf2onnx/convert.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -291,19 +291,65 @@ def tensor_names_from_structed(concrete_func, input_names, output_names):
291291
return tensors_to_rename
292292

293293

294+
def _rename_duplicate_keras_model_names(model):
295+
"""
296+
In very rare cases, keras has a bug where it will give multiple outputs the same name.
297+
We must edit the model or the TF trace will fail. Returns old_out_names (or None if no edit was made).
298+
IMPORTANT: model may be edited. Assign model.output_names to old_out_names to restore.
299+
"""
300+
old_out_names = None
301+
if model.output_names and len(set(model.output_names)) != len(model.output_names):
302+
# In very rare cases, keras has a bug where it will give multiple outputs the same name
303+
# We must edit the model or the TF trace will fail
304+
old_out_names = model.output_names
305+
used_names = set()
306+
new_out_names = []
307+
for name in model.output_names:
308+
new_name = name
309+
i = 0
310+
while new_name in used_names:
311+
i += 1
312+
new_name = name + "_" + str(i)
313+
used_names.add(new_name)
314+
new_out_names.append(new_name)
315+
model.output_names = new_out_names
316+
return old_out_names
317+
318+
319+
def _is_legacy_keras_model(model):
320+
"""Inspects model class to determine if it is from tf or legacy keras"""
321+
322+
logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
323+
unknown_type_err = "model is not instance of tf.keras.Model or keras.Model"
324+
if isinstance(model, tf.keras.Model):
325+
return False
326+
try:
327+
import keras # pylint: disable=import-outside-toplevel
328+
if isinstance(model, keras.Model):
329+
return True
330+
logger.warning(unknown_type_err)
331+
except ImportError:
332+
logger.warning(unknown_type_err)
333+
return False
334+
335+
294336
def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
295337
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None,
296338
target=None, large_model=False, output_path=None):
297339
"""from_keras for tf 1.15"""
298-
299340
input_names = [t.name for t in model.inputs]
300341
output_names = [t.name for t in model.outputs]
342+
old_out_names = _rename_duplicate_keras_model_names(model)
301343
tensors_to_rename = dict(zip(input_names, model.input_names))
302-
if len(set(model.output_names)) == len(model.output_names):
303-
# In very rare cases, keras has a bug where it will give multiple outputs the same name
304-
tensors_to_rename.update(zip(output_names, model.output_names))
344+
tensors_to_rename.update(zip(output_names, model.output_names))
345+
if old_out_names is not None:
346+
model.output_names = old_out_names
305347

306-
sess = tf.keras.backend.get_session(model.outputs)
348+
if _is_legacy_keras_model(model):
349+
import keras # pylint: disable=import-outside-toplevel
350+
sess = keras.backend.get_session()
351+
else:
352+
sess = tf.keras.backend.get_session(model.outputs)
307353

308354
with tf.device("/cpu:0"):
309355
frozen_graph, initialized_tables = tf_loader.freeze_session(sess, input_names, output_names, get_tables=True)
@@ -351,6 +397,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
351397
Returns:
352398
An ONNX model_proto and an external_tensor_storage dict.
353399
"""
400+
old_out_names = _rename_duplicate_keras_model_names(model)
354401
if LooseVersion(tf.__version__) < "2.0":
355402
return _from_keras_tf1(model, input_signature, opset, custom_ops, custom_op_handlers, custom_rewriter,
356403
inputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
@@ -370,9 +417,21 @@ def wrap_call(*args, training=False, **kwargs):
370417
return model_call(*args, **kwargs)
371418
model.call = wrap_call
372419
function = _saving_utils.trace_model_call(model, input_signature)
373-
concrete_func = function.get_concrete_function()
374-
# Put it back
375-
model.call = model_call
420+
try:
421+
# Legacy keras get make TF erroneously enter eager mode when it should be making symbolic tensors
422+
import tensorflow_core # pylint: disable=import-outside-toplevel
423+
old_get_learning_phase = tensorflow_core.python.keras.backend.learning_phase
424+
tensorflow_core.python.keras.backend.learning_phase = \
425+
tensorflow_core.python.keras.backend.symbolic_learning_phase
426+
except ImportError:
427+
old_get_learning_phase = None
428+
try:
429+
concrete_func = function.get_concrete_function()
430+
finally:
431+
# Put everything back
432+
model.call = model_call
433+
if old_get_learning_phase is not None:
434+
tensorflow_core.python.keras.backend.learning_phase = old_get_learning_phase
376435

377436
# These inputs will be removed during freezing (includes resources, etc.)
378437
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
@@ -392,6 +451,9 @@ def wrap_call(*args, training=False, **kwargs):
392451
# Other models specify output order using the key order of structured_outputs
393452
output_names = [reverse_lookup[out] for out in concrete_func.structured_outputs.keys()]
394453

454+
if old_out_names is not None:
455+
model.output_names = old_out_names
456+
395457
with tf.device("/cpu:0"):
396458
frozen_graph, initialized_tables = \
397459
tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)

tf2onnx/tf_loader.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,19 @@ def inputs_without_resource(sess, input_names):
113113
def convert_variables_to_constants_large_model(func):
114114
# For large models we use internal tf methods as a hack
115115

116+
if tf.__version__.startswith("2.1.") or tf.__version__.startswith("2.0."):
117+
from tensorflow.python.framework import convert_to_constants
118+
orig_fn = convert_to_constants._construct_concrete_function # pylint: disable=protected-access
119+
def fake_construct_fn(func, output_graph_def, converted_input_indices):
120+
# Return graph_def without loading it to avoid crash. Will fix errors in graph_def later.
121+
return output_graph_def
122+
convert_to_constants._construct_concrete_function = fake_construct_fn # pylint: disable=protected-access
123+
try:
124+
frozen_graph_def = convert_to_constants.convert_variables_to_constants_v2(func, lower_control_flow=False)
125+
finally:
126+
convert_to_constants._construct_concrete_function = orig_fn # pylint: disable=protected-access
127+
return frozen_graph_def
128+
116129
if tf.__version__.startswith("2.2."):
117130
try:
118131
from tensorflow.python.framework.convert_to_constants import \
@@ -156,9 +169,9 @@ def make_tensor_proto_wrapped(values, dtype=None, shape=None, verify_shape=False
156169
def fix_freezing_errors(graph_def):
157170
assign_var_ops = []
158171
for i in reversed(range(len(graph_def.node))):
159-
if graph_def.node[i].op == "AssignVariableOp":
172+
if graph_def.node[i].op in ["AssignVariableOp", "AssignSubVariableOp"]:
160173
assign_var_ops.append(graph_def.node.pop(i).name)
161-
logger.warning("Removed AssignVariableOp %s", assign_var_ops[-1])
174+
logger.warning("Removed %s %s", graph_def.node[i].op, assign_var_ops[-1])
162175
names_to_remove = set(assign_var_ops)
163176
for n in graph_def.node:
164177
for i in reversed(range(len(n.input))):
@@ -218,9 +231,9 @@ def from_function(func, input_names, output_names, large_model=False):
218231
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False, aggressive_inlining=True)
219232
except ValueError as e:
220233
if "incompatible with expected resource" in str(e):
221-
frozen_func = convert_variables_to_constants_large_model(func)
234+
bad_graph_def = convert_variables_to_constants_large_model(func)
222235
logger.warning("TF freezing failed. Attempting to fix freezing errors.")
223-
graph_def = fix_freezing_errors(frozen_func)
236+
graph_def = fix_freezing_errors(bad_graph_def)
224237
else:
225238
raise e
226239
else:

0 commit comments

Comments
 (0)