Skip to content

Commit 6c417b4

Browse files
Added tests from keras2onnx (#1563)
* Added tests from keras2onnx Signed-off-by: Tom Wildenhain <[email protected]> * Add pytest to setup Signed-off-by: Tom Wildenhain <[email protected]> * Add keras2onnx api to tf2onnx Signed-off-by: Tom Wildenhain <[email protected]> * Mock keras2onnx utility methods Signed-off-by: Tom Wildenhain <[email protected]> * Update keras2onnx tests to use ORT 1.8.0 Signed-off-by: Tom Wildenhain <[email protected]> * Bugfix for Slice Signed-off-by: Tom Wildenhain <[email protected]> * Fix RU rewriter Signed-off-by: Tom Wildenhain <[email protected]> * Remove tf1.11 tests Signed-off-by: Tom Wildenhain <[email protected]> * Improve keras API documentation Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d7ee792 commit 6c417b4

File tree

9 files changed

+3577
-0
lines changed

9 files changed

+3577
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Tests copied from keras2onnx
2+
3+
jobs:
4+
5+
- job: 'Test'
6+
pool:
7+
vmImage: 'Ubuntu-16.04'
8+
strategy:
9+
matrix:
10+
11+
Python36-tf1.15:
12+
python.version: '3.6'
13+
ONNX_PATH: onnx==1.5.0
14+
TENSORFLOW_PATH: tensorflow==1.15.0
15+
INSTALL_ORT: pip install onnxruntime==1.8.0
16+
17+
Python37-tf2.1:
18+
python.version: '3.7'
19+
ONNX_PATH: onnx==1.6.0
20+
TENSORFLOW_PATH: tensorflow-cpu==2.1.0
21+
INSTALL_ORT: pip install onnxruntime==1.8.0
22+
23+
Python38-tf2.2:
24+
python.version: '3.8'
25+
ONNX_PATH: onnx==1.7.0
26+
TENSORFLOW_PATH: tensorflow-cpu==2.2.0
27+
INSTALL_ORT: pip install onnxruntime==1.8.0
28+
29+
Python38-tf2.3:
30+
python.version: '3.8'
31+
ONNX_PATH: onnx==1.8.0
32+
TENSORFLOW_PATH: tensorflow-cpu==2.3.0
33+
INSTALL_ORT: pip install onnxruntime==1.8.0
34+
35+
Python38-tf2.5:
36+
python.version: '3.8'
37+
ONNX_PATH: onnx==1.8.0
38+
TENSORFLOW_PATH: tensorflow-cpu==2.5.0
39+
INSTALL_ORT: pip install onnxruntime==1.8.0
40+
41+
steps:
42+
- script: sudo install -d -m 0777 /home/vsts/.conda/envs
43+
displayName: Fix Conda permissions
44+
45+
- task: CondaEnvironment@1
46+
inputs:
47+
createCustomEnvironment: true
48+
environmentName: 'py$(python.version)'
49+
packageSpecs: 'python=$(python.version)'
50+
51+
- script: |
52+
python -m pip install --upgrade pip
53+
conda config --set always_yes yes --set changeps1 no
54+
pip install $(ONNX_PATH)
55+
pip install h5py==2.9.0
56+
pip install numpy==1.19
57+
pip install $(TENSORFLOW_PATH)
58+
pip install git+https://github.com/microsoft/onnxconverter-common
59+
pip install -r requirements.txt
60+
pip install -r requirements-dev.txt
61+
pip install pytest pytest-cov pytest-runner
62+
$(INSTALL_ORT)
63+
displayName: 'Install dependencies'
64+
65+
- script: |
66+
pip install -e .
67+
python -c "import onnxruntime"
68+
python -c "import onnxconverter_common"
69+
pytest keras2onnx_tests --doctest-modules --junitxml=junit/test-results.xml
70+
displayName: 'pytest'
71+
72+
- task: PublishTestResults@2
73+
inputs:
74+
testResultsFiles: '**/test-results.xml'
75+
testRunTitle: 'Python $(python.version)'
76+
condition: succeededOrFailed()

keras2onnx_tests/conftest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import os
4+
import pytest
5+
6+
from mock_keras2onnx.proto import keras
7+
from test_utils import run_onnx_runtime
8+
9+
K = keras.backend
10+
11+
12+
@pytest.fixture(scope='function')
13+
def runner():
14+
model_files = []
15+
16+
def runner_func(*args, **kwargs):
17+
return run_onnx_runtime(*args, model_files, **kwargs)
18+
19+
# Ensure Keras layer naming is reset for each function
20+
K.reset_uids()
21+
# Reset the TensorFlow session to avoid resource leaking between tests
22+
K.clear_session()
23+
24+
# Provide wrapped run_onnx_runtime function
25+
yield runner_func
26+
27+
# Remove model files
28+
for fl in model_files:
29+
os.remove(fl)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import os
4+
import tensorflow
5+
from distutils.version import StrictVersion
6+
7+
# Rather than using ONNX protobuf definition throughout our codebase, we import ONNX protobuf definition here so that
8+
# we can conduct quick fixes by overwriting ONNX functions without changing any lines elsewhere.
9+
from onnx import onnx_pb as onnx_proto
10+
from onnx import helper
11+
from onnx import save_model as save_model
12+
13+
14+
def _check_onnx_version():
15+
import pkg_resources
16+
min_required_version = pkg_resources.parse_version('1.0.1')
17+
current_version = pkg_resources.get_distribution('onnx').parsed_version
18+
assert current_version >= min_required_version, 'Keras2ONNX requires ONNX version 1.0.1 or a newer one'
19+
20+
21+
_check_onnx_version()
22+
23+
24+
def is_tensorflow_older_than(version_str):
25+
return StrictVersion(tensorflow.__version__.split('-')[0]) < StrictVersion(version_str)
26+
27+
28+
def is_tensorflow_later_than(version_str):
29+
return StrictVersion(tensorflow.__version__.split('-')[0]) > StrictVersion(version_str)
30+
31+
32+
is_tf_keras = False
33+
str_tk_keras = os.environ.get('TF_KERAS', None)
34+
if str_tk_keras is None:
35+
# With tensorflow 2.x, be default we loaded tf.keras as the framework, instead of Keras
36+
is_tf_keras = not is_tensorflow_older_than('2.0.0')
37+
else:
38+
is_tf_keras = str_tk_keras != '0'
39+
40+
if is_tf_keras:
41+
from tensorflow.python import keras
42+
else:
43+
try:
44+
import keras
45+
46+
if keras.Model == tensorflow.keras.Model: # since keras 2.4, keras and tf.keras is unified.
47+
is_tf_keras = True
48+
except ImportError:
49+
is_tf_keras = True
50+
from tensorflow.python import keras
51+
52+
53+
def is_keras_older_than(version_str):
54+
return StrictVersion(keras.__version__.split('-')[0]) < StrictVersion(version_str)
55+
56+
57+
def is_keras_later_than(version_str):
58+
return StrictVersion(keras.__version__.split('-')[0]) > StrictVersion(version_str)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import os
4+
import tensorflow as _tf
5+
6+
from distutils.version import StrictVersion
7+
8+
is_tf2 = StrictVersion(_tf.__version__.split('-')[0]) >= StrictVersion('2.0.0')
9+
10+
11+
def normalize_tensor_shape(tensor_shape):
12+
if is_tf2:
13+
return [d for d in tensor_shape]
14+
else:
15+
return [d.value for d in tensor_shape]
16+
17+
18+
def dump_graph_into_tensorboard(tf_graph):
19+
# type: (_tf.Graph) -> None
20+
_tb_log_dir = os.environ.get('TB_LOG_DIR')
21+
if _tb_log_dir:
22+
if is_tf2:
23+
from tensorflow.python.ops.summary_ops_v2 import graph as write_graph
24+
pb_visual_writer = _tf.summary.create_file_writer(_tb_log_dir)
25+
with pb_visual_writer.as_default():
26+
write_graph(tf_graph)
27+
else:
28+
from tensorflow.python.summary import summary
29+
pb_visual_writer = summary.FileWriter(_tb_log_dir)
30+
pb_visual_writer.add_graph(tf_graph)
31+
32+
33+
if is_tf2:
34+
tensorflow = _tf.compat.v1
35+
36+
def is_subclassed(layer):
37+
"""Returns True if the object is a subclassed layer or subclassed model."""
38+
return (layer.__module__.find('keras.engine') == -1 and
39+
layer.__module__.find('keras.layers') == -1)
40+
else:
41+
tensorflow = _tf
42+
43+
def is_subclassed(layer):
44+
return False

keras2onnx_tests/test_cgan.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
import tensorflow as tf
5+
import mock_keras2onnx
6+
import numpy as np
7+
from mock_keras2onnx.proto import keras, is_tf_keras
8+
from tf2onnx.keras2onnx_api import convert_keras
9+
from distutils.version import StrictVersion
10+
11+
Activation = keras.layers.Activation
12+
BatchNormalization = keras.layers.BatchNormalization
13+
Dense = keras.layers.Dense
14+
Dropout = keras.layers.Dropout
15+
Embedding = keras.layers.Embedding
16+
Flatten = keras.layers.Flatten
17+
Input = keras.layers.Input
18+
LeakyReLU = keras.layers.LeakyReLU
19+
multiply = keras.layers.multiply
20+
Reshape = keras.layers.Reshape
21+
UpSampling2D = keras.layers.UpSampling2D
22+
23+
Sequential = keras.models.Sequential
24+
Model = keras.models.Model
25+
26+
27+
# From https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py
28+
class CGAN():
29+
def __init__(self):
30+
# Input shape
31+
self.img_rows = 28
32+
self.img_cols = 28
33+
self.channels = 1
34+
self.img_shape = (self.img_rows, self.img_cols, self.channels)
35+
self.num_classes = 10
36+
self.latent_dim = 100
37+
38+
# Build and compile the discriminator
39+
self.discriminator = self.build_discriminator()
40+
41+
# Build the generator
42+
self.generator = self.build_generator()
43+
44+
# The generator takes noise and the target label as input
45+
# and generates the corresponding digit of that label
46+
noise = Input(shape=(self.latent_dim,))
47+
label = Input(shape=(1,))
48+
img = self.generator([noise, label])
49+
50+
# For the combined model we will only train the generator
51+
self.discriminator.trainable = False
52+
53+
# The discriminator takes generated image as input and determines validity
54+
# and the label of that image
55+
valid = self.discriminator([img, label])
56+
57+
# The combined model (stacked generator and discriminator)
58+
# Trains generator to fool discriminator
59+
self.combined = Model([noise, label], valid)
60+
61+
def get_model(self):
62+
return self.combined
63+
64+
def build_generator(self):
65+
model = Sequential()
66+
67+
model.add(Dense(256, input_dim=self.latent_dim))
68+
69+
model.add(LeakyReLU(alpha=0.2))
70+
model.add(BatchNormalization(momentum=0.8))
71+
model.add(Dense(512))
72+
model.add(LeakyReLU(alpha=0.2))
73+
model.add(BatchNormalization(momentum=0.8))
74+
model.add(Dense(1024))
75+
model.add(LeakyReLU(alpha=0.2))
76+
model.add(BatchNormalization(momentum=0.8))
77+
78+
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
79+
model.add(Reshape(self.img_shape))
80+
81+
noise = Input(shape=(self.latent_dim,))
82+
label = Input(shape=(1,), dtype='int32')
83+
label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
84+
85+
model_input = multiply([noise, label_embedding])
86+
img = model(model_input)
87+
88+
return Model([noise, label], img)
89+
90+
def build_discriminator(self):
91+
model = Sequential()
92+
93+
model.add(Dense(512, input_dim=np.prod(self.img_shape)))
94+
model.add(LeakyReLU(alpha=0.2))
95+
model.add(Dense(512))
96+
model.add(LeakyReLU(alpha=0.2))
97+
model.add(Dropout(0.4))
98+
model.add(Dense(512))
99+
model.add(LeakyReLU(alpha=0.2))
100+
model.add(Dropout(0.4))
101+
model.add(Dense(1, activation='sigmoid'))
102+
103+
model.add(Dense(1, activation='sigmoid'))
104+
105+
img = Input(shape=self.img_shape)
106+
label = Input(shape=(1,), dtype='int32')
107+
108+
label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
109+
flat_img = Flatten()(img)
110+
111+
model_input = multiply([flat_img, label_embedding])
112+
113+
validity = model(model_input)
114+
115+
return Model([img, label], validity)
116+
117+
118+
@pytest.mark.skipif(mock_keras2onnx.proto.tfcompat.is_tf2 and is_tf_keras, reason="Tensorflow 1.x only tests.")
119+
@pytest.mark.skipif(is_tf_keras and StrictVersion(tf.__version__.split('-')[0]) < StrictVersion("1.14.0"),
120+
reason="Not supported before tensorflow 1.14.0 for tf_keras")
121+
def test_CGAN(runner):
122+
keras_model = CGAN().combined
123+
batch = 5
124+
x = np.random.rand(batch, 100).astype(np.float32)
125+
y = np.random.rand(batch, 1).astype(np.float32)
126+
expected = keras_model.predict([x, y])
127+
onnx_model = convert_keras(keras_model, keras_model.name)
128+
assert runner(onnx_model.graph.name, onnx_model,
129+
{keras_model.input_names[0]: x, keras_model.input_names[1]: y}, expected)

0 commit comments

Comments
 (0)