Skip to content

Commit b106603

Browse files
Made from_keras work for tf1 (#1529)
* Made from_keras work for tf1 Signed-off-by: Tom Wildenhain <[email protected]> * pylint Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 936cec6 commit b106603

File tree

3 files changed

+60
-17
lines changed

3 files changed

+60
-17
lines changed

tests/test_api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from onnx import helper
1414

1515
from common import check_tf_min_version, unittest_main, requires_custom_ops, check_opset_min_version
16+
from tf2onnx.tf_loader import is_tf2
1617
from backend_test_base import Tf2OnnxBackendTestBase
1718
import tf2onnx
1819

@@ -67,16 +68,16 @@ def _test_keras_api(self, large_model=False):
6768
ky1 = model.predict([x, n])
6869
self.assertAllClose(ky1, oy[0], rtol=0.3, atol=0.1)
6970

70-
@check_tf_min_version("2.0")
71+
@check_tf_min_version("1.15")
7172
def test_keras_api(self):
7273
self._test_keras_api(large_model=False)
7374

74-
@check_tf_min_version("2.2")
75+
@check_tf_min_version("1.15")
7576
def test_keras_api_large(self):
7677
self._test_keras_api(large_model=True)
7778

7879
@requires_custom_ops()
79-
@check_tf_min_version("2.0")
80+
@check_tf_min_version("1.15")
8081
@check_opset_min_version(11, "SparseToDense")
8182
def test_keras_hashtable(self):
8283

@@ -101,6 +102,8 @@ def test_keras_hashtable(self):
101102

102103
inp1 = np.array([[2.], [3.]], dtype=np.float32)
103104
inp2 = np.array([["a"], ["b"]], dtype=np.str)
105+
if not is_tf2():
106+
tf.keras.backend.get_session().run(tf.tables_initializer(name='init_all_tables'))
104107
k_res = model.predict([inp1, inp2])
105108
spec = (tf.TensorSpec((None, 1), dtype=tf.float32, name="f_inp"),
106109
tf.TensorSpec((None, 1), tf.string, name="s_inp"))

tf2onnx/convert.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,42 @@ def tensor_names_from_structed(concrete_func, input_names, output_names):
284284
return tensors_to_rename
285285

286286

287+
def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
288+
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None,
289+
target=None, large_model=False, output_path=None):
290+
"""from_keras for tf 1.15"""
291+
292+
input_names = [t.name for t in model.inputs]
293+
output_names = [t.name for t in model.outputs]
294+
tensors_to_rename = dict(zip(input_names, model.input_names))
295+
if len(set(model.output_names)) == len(model.output_names):
296+
# In very rare cases, keras has a bug where it will give multiple outputs the same name
297+
tensors_to_rename.update(zip(output_names, model.output_names))
298+
299+
sess = tf.keras.backend.get_session(model.outputs)
300+
301+
with tf.device("/cpu:0"):
302+
frozen_graph, initialized_tables = tf_loader.freeze_session(sess, input_names, output_names, get_tables=True)
303+
model_proto, external_tensor_storage = _convert_common(
304+
frozen_graph,
305+
name=model.name,
306+
continue_on_error=True,
307+
target=target,
308+
opset=opset,
309+
custom_op_handlers=custom_ops,
310+
extra_opset=extra_opset,
311+
shape_override=shape_override,
312+
input_names=input_names,
313+
output_names=output_names,
314+
inputs_as_nchw=inputs_as_nchw,
315+
large_model=large_model,
316+
tensors_to_rename=tensors_to_rename,
317+
initialized_tables=initialized_tables,
318+
output_path=output_path)
319+
320+
return model_proto, external_tensor_storage
321+
322+
287323
def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
288324
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None,
289325
target=None, large_model=False, output_path=None):
@@ -306,7 +342,8 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
306342
An ONNX model_proto and an external_tensor_storage dict.
307343
"""
308344
if LooseVersion(tf.__version__) < "2.0":
309-
raise NotImplementedError("from_keras requires tf-2.0 or newer")
345+
return _from_keras_tf1(model, input_signature, opset, custom_ops, custom_op_handlers, custom_rewriter,
346+
inputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
310347

311348
from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel
312349

tf2onnx/tf_loader.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def from_function(func, input_names, output_names, large_model=False):
214214
return graph_def
215215

216216

217-
def freeze_session(sess, input_names=None, output_names=None):
217+
def freeze_session(sess, input_names=None, output_names=None, get_tables=False):
218218
"""Freezes the state of a session into a pruned computation graph."""
219219
output_node_names = [i.split(':')[:-1][0] for i in output_names]
220220
keep_var_names = [i.split(':')[:-1][0] for i in input_names]
@@ -226,6 +226,19 @@ def freeze_session(sess, input_names=None, output_names=None):
226226
for node in graph_def.node:
227227
node.device = ""
228228
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
229+
table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
230+
if get_tables:
231+
initialized_tables = {}
232+
tf.tables_initializer().run(session=sess)
233+
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
234+
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
235+
try:
236+
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
237+
k, v = sess.run([k, v])
238+
initialized_tables[n] = (k, v)
239+
except Exception: # pylint: disable=broad-except
240+
logger.warning("Could not initialize table with shared_name = %r", n)
241+
return graph_def, initialized_tables
229242
return graph_def
230243

231244

@@ -348,18 +361,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
348361
if output_tensor.name not in output_names:
349362
output_names.append(output_tensor.name)
350363
tensors_to_rename[output_tensor.name] = structured_name
351-
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
352-
table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
353-
initialized_tables = {}
354-
tf.tables_initializer().run()
355-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
356-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
357-
try:
358-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
359-
k, v = sess.run([k, v])
360-
initialized_tables[n] = (k, v)
361-
except Exception: # pylint: disable=broad-except
362-
logger.warning("Could not initialize table with shared_name = %r", n)
364+
frozen_graph, initialized_tables = \
365+
freeze_session(sess, input_names=input_names, output_names=output_names, get_tables=True)
363366
return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
364367

365368

0 commit comments

Comments
 (0)