Skip to content

Commit 3324fec

Browse files
committed
Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into tflfft
2 parents 7087ee9 + becdcba commit 3324fec

File tree

6 files changed

+80
-52
lines changed

6 files changed

+80
-52
lines changed

tests/backend_test_base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,14 @@ def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeh
190190
graph_def = freeze_session(sess,
191191
input_names=list(feed_dict.keys()),
192192
output_names=outputs)
193-
table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
193+
table_info = get_hash_table_info(graph_def)
194194
initialized_tables = {}
195-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
196-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
197-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
198-
initialized_tables[n] = (sess.run(k), sess.run(v))
195+
for info in table_info:
196+
if info.shared_name is None:
197+
continue
198+
h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
199+
k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
200+
initialized_tables[info.shared_name] = (sess.run(k), sess.run(v))
199201

200202
tf_reset_default_graph()
201203
with tf_session() as sess:

tests/common.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,18 @@ def check_tf_min_version(min_required_version, message=""):
286286

287287

288288
def skip_tf_versions(excluded_versions, message=""):
289-
""" Skip if tf_version SEMANTICALLY matches any of excluded_versions. """
289+
""" Skip if tf_version matches any of excluded_versions. """
290+
if not isinstance(excluded_versions, list):
291+
excluded_versions = [excluded_versions]
290292
config = get_test_config()
291293
condition = False
292294
reason = _append_message("conversion excludes tf {}".format(excluded_versions), message)
293295

294-
current_tokens = str(config.tf_version).split('.')
295296
for excluded_version in excluded_versions:
296-
exclude_tokens = excluded_version.split('.')
297-
# assume len(exclude_tokens) <= len(current_tokens)
298-
for i, exclude in enumerate(exclude_tokens):
299-
if not current_tokens[i] == exclude:
300-
break
301-
condition = True
297+
# tf version with same specificity as excluded_version
298+
tf_version = '.'.join(str(config.tf_version).split('.')[:excluded_version.count('.') + 1])
299+
if excluded_version == tf_version:
300+
condition = True
302301

303302
return unittest.skipIf(condition, reason)
304303

tests/test_lstm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,8 @@ def func(x):
674674
feed_dict = {"input_1:0": x_val}
675675
input_names_with_port = ["input_1:0"]
676676
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
677-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
677+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
678+
require_lstm_count=2)
678679

679680
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
680681
@skip_tf_versions("2.1", "Bug in TF 2.1")
@@ -721,7 +722,8 @@ def func(x, y1, y2):
721722
feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
722723
input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
723724
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
724-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
725+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
726+
require_lstm_count=2)
725727

726728

727729
if __name__ == '__main__':

tf2onnx/optimizer/einsum_optimizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,20 @@ def _preprocess(a, axis):
9090
return (np.concatenate(targs, axis),)
9191

9292
def _op_gemm(self, a, b, c=None, alpha=None, beta=None, # pylint: disable=C0103
93-
transA=False, transB=False): # pylint: disable=C0103
93+
transA=None, transB=None): # pylint: disable=C0103
9494
"Runtime for operator."
9595
if alpha is not None:
9696
alpha = alpha.f
9797
if beta is not None:
9898
beta = beta.f
99+
if transA is None:
100+
transA = False
101+
else:
102+
transA = transA.i
103+
if transB is None:
104+
transB = False
105+
else:
106+
transB = transB.i
99107

100108
def _gemm00(a, b, c, alpha, beta):
101109
o = np.dot(a, b) * alpha

tf2onnx/tf_loader.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""Methods to load tensorflow graph from graphdef, checkpoint or saved_model."""
55

66
import logging
7+
import uuid
78
from distutils.version import LooseVersion
89

910
import tensorflow as tf
@@ -15,7 +16,8 @@
1516
from tensorflow.python.util import compat
1617

1718
from tf2onnx import utils
18-
from tf2onnx.tf_utils import get_tf_version, tflist_to_onnx, get_hash_table_info, replace_placeholders_with_tables
19+
from tf2onnx.tf_utils import (get_tf_version, tflist_to_onnx, get_hash_table_info, replace_placeholders_with_tables,
20+
HashTableInfo)
1921

2022
logger = logging.getLogger(__name__)
2123

@@ -184,7 +186,7 @@ def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
184186
err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
185187

186188
# Avoid errors due to bug in TF freezing
187-
removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
189+
removed_resource_to_placeholder, placeholder_to_resource, graph_captures_copy, func_captures_copy = \
188190
_remove_non_variable_resources_from_captures(concrete_func)
189191

190192
try:
@@ -197,16 +199,28 @@ def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
197199
# We might be returning the concrete_func so let's put it back in working order
198200
_restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
199201

200-
table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
202+
table_info = get_hash_table_info(frozen_graph)
201203
placeholder_to_table_info = {}
202-
_get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
204+
_get_hash_table_info_from_trackable(trackable, table_info,
203205
removed_resource_to_placeholder, placeholder_to_table_info)
204206

205207
initialized_tables = {}
206-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
207-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
208+
for info in table_info:
209+
if info.shared_name is not None:
210+
h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
211+
n = info.shared_name
212+
elif info.resource_input in placeholder_to_resource and info.resource_input not in placeholder_to_table_info:
213+
# We found a lookup op with no corresponding HashTable op, but we can associate the placeholder input
214+
# from the op with the resource handle from graph captures and make up a shared_name
215+
h = placeholder_to_resource[info.resource_input]
216+
n = str(uuid.uuid4()).encode()
217+
info.shared_name = n
218+
placeholder_to_table_info[info.resource_input] = info
219+
else:
220+
# Found a lookup op but the corresponding HashTable op has already been found and processed.
221+
continue
208222
try:
209-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
223+
k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
210224
initialized_tables[n] = (k.numpy(), v.numpy())
211225
except Exception: # pylint: disable=broad-except
212226
logger.warning("Could not initialize table with shared_name = %r", n)
@@ -260,14 +274,14 @@ def freeze_session(sess, input_names=None, output_names=None, get_tables=False):
260274
for node in graph_def.node:
261275
node.device = ""
262276
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
263-
table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
277+
table_info = get_hash_table_info(graph_def)
264278
if get_tables:
265279
initialized_tables = {}
266280
tf.tables_initializer().run(session=sess)
267-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
268-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
281+
for info in table_info:
282+
h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
269283
try:
270-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
284+
k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
271285
k, v = sess.run([k, v])
272286
initialized_tables[n] = (k, v)
273287
except Exception: # pylint: disable=broad-except
@@ -403,7 +417,7 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
403417
return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
404418

405419

406-
def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
420+
def _get_hash_table_info_from_trackable(trackable, table_info,
407421
removed_resource_to_placeholder, placeholder_to_table_info):
408422
# pylint: disable=protected-access
409423
stack = [trackable]
@@ -420,26 +434,22 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
420434
continue
421435
for t in r.__dict__.values() if hasattr(r, '__dict__') else []:
422436
if isinstance(t, TfStaticHashTableType) and hasattr(t, '_shared_name'):
423-
table_names.append(t._shared_name.encode())
424-
key_dtypes.append(t.key_dtype.as_datatype_enum)
425-
value_dtypes.append(t.value_dtype.as_datatype_enum)
437+
info = HashTableInfo(t._shared_name.encode(), t.key_dtype.as_datatype_enum,
438+
t.value_dtype.as_datatype_enum)
439+
table_info.append(info)
426440
table_handle = id(t.resource_handle)
427441
if table_handle in removed_resource_to_placeholder:
428-
table_info = (table_names[-1], key_dtypes[-1], value_dtypes[-1])
429-
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
442+
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = info
430443
if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource'):
431444
try:
432445
table_handle = id(r.resource_handle)
433446
except Exception: # pylint: disable=broad-except
434447
continue
435448
initializer = r._create_resource.concrete_functions[0].function_def
436-
new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
437-
table_names.extend(new_names)
438-
key_dtypes.extend(new_k_dtypes)
439-
value_dtypes.extend(new_v_dtypes)
440-
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
441-
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
442-
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
449+
new_table_info = get_hash_table_info(initializer.node_def)
450+
table_info.extend(new_table_info)
451+
if table_handle in removed_resource_to_placeholder and len(new_table_info) == 1:
452+
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = new_table_info[0]
443453

444454

445455
def _remove_non_variable_resources_from_captures(concrete_func):
@@ -449,6 +459,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
449459
"""
450460
# pylint: disable=protected-access
451461
resource_id_to_placeholder = {}
462+
placeholder_to_resource = {}
452463
graph_captures_copy = None
453464
func_captures_copy = None
454465
if hasattr(concrete_func.graph, '_captures') and hasattr(concrete_func, '_captured_inputs'):
@@ -459,6 +470,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
459470
val_tensor, name_tensor = v
460471
if val_tensor.dtype == tf.resource and id(val_tensor) not in variable_handles:
461472
resource_id_to_placeholder[id(val_tensor)] = name_tensor.name.split(':')[0]
473+
placeholder_to_resource[name_tensor.name.split(':')[0]] = val_tensor
462474
del concrete_func.graph._captures[k]
463475
for i in reversed(range(len(concrete_func._captured_inputs))):
464476
if concrete_func._captured_inputs[i] is val_tensor:
@@ -472,7 +484,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
472484
else:
473485
logger.warning(
474486
"Could not search for non-variable resources. Concrete function internal representation may have changed.")
475-
return resource_id_to_placeholder, graph_captures_copy, func_captures_copy
487+
return resource_id_to_placeholder, placeholder_to_resource, graph_captures_copy, func_captures_copy
476488

477489

478490
def _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy):

tf2onnx/tf_utils.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,13 @@ def is_huge_shape(x):
278278
logger.info("Computed %d values for constant folding", len(outputs_to_values))
279279
return outputs_to_values, outputs_to_dtypes
280280

281+
class HashTableInfo:
282+
def __init__(self, shared_name, key_dtype, val_dtype, resource_input=None):
283+
self.shared_name = shared_name
284+
self.key_dtype = key_dtype
285+
self.val_dtype = val_dtype
286+
self.resource_input = resource_input
287+
281288
def get_hash_table_info(nodes_or_graph_def):
282289
"""
283290
Return lists of the shared_names, key_dtypes, and value_dtypes of all hash tables declared in the graph_def
@@ -287,18 +294,16 @@ def get_hash_table_info(nodes_or_graph_def):
287294
nodes = nodes_or_graph_def.node
288295
else:
289296
nodes = nodes_or_graph_def
290-
names = []
291-
key_dtypes = []
292-
val_dtypes = []
297+
info = []
293298
for n in nodes:
299+
if n.op == "LookupTableFindV2":
300+
info.append(HashTableInfo(None, n.attr['Tin'].type, n.attr['Tout'].type, n.input[0]))
294301
if n.op in ["HashTableV2", "MutableHashTableV2"]:
295302
if all(k in n.attr for k in ['shared_name', 'key_dtype', 'value_dtype']):
296303
name = n.attr['shared_name'].s
297304
if name != b'':
298-
names.append(name)
299-
key_dtypes.append(n.attr['key_dtype'].type)
300-
val_dtypes.append(n.attr['value_dtype'].type)
301-
return names, key_dtypes, val_dtypes
305+
info.append(HashTableInfo(name, n.attr['key_dtype'].type, n.attr['value_dtype'].type))
306+
return info
302307

303308
def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
304309
"""
@@ -307,13 +312,13 @@ def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
307312
"""
308313
for n in graph_def.node:
309314
if n.op == "Placeholder" and n.name in placeholder_to_table_info:
310-
name, key_dtype, val_dtype = placeholder_to_table_info[n.name]
315+
info = placeholder_to_table_info[n.name]
311316
for a in list(n.attr):
312317
del n.attr[a]
313318
n.op = "HashTableV2"
314-
n.attr['shared_name'].s = name
315-
n.attr['key_dtype'].type = key_dtype
316-
n.attr['value_dtype'].type = val_dtype
319+
n.attr['shared_name'].s = info.shared_name
320+
n.attr['key_dtype'].type = info.key_dtype
321+
n.attr['value_dtype'].type = info.val_dtype
317322

318323
def read_tf_node_def_attrs(node_def, input_dtypes, input_shapes):
319324
"""Given a tf node def, returns a dict of attribute names to values"""

0 commit comments

Comments
 (0)