Commit 7900e3a

Reenabled placeholderwithdefault tests (#1262)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent f6acec4 commit 7900e3a

File tree (4 files changed, +25 −23 lines)

    tests/backend_test_base.py
    tests/test_backend.py
    tf2onnx/tf_loader.py
    tf2onnx/tf_utils.py

tests/backend_test_base.py

Lines changed: 4 additions & 3 deletions
@@ -99,7 +99,7 @@ def run_backend(self, g, outputs, input_dict, large_model=False):
     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                       convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
                       check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
-                      large_model=False):
+                      large_model=False, premade_placeholders=False):
         # optional - passed to process_tf_graph
         if process_args is None:
             process_args = {}
@@ -145,8 +145,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
         with tf_session() as sess:
             tf_set_random_seed(1)
             input_list = []
-            for k, v in clean_feed_dict.items():
-                input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
+            if not premade_placeholders:
+                for k, v in clean_feed_dict.items():
+                    input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
             func(*input_list)
             variables_lib.global_variables_initializer().run()
             tf_tables_initializer().run()
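
Note: the new premade_placeholders flag tells the harness to skip synthesizing tf.placeholder inputs from the feed dict, because the test's func() is expected to build its own input nodes (for example with placeholder_with_default). A minimal standalone sketch of that guard, with a hypothetical helper name (the real logic sits inline in run_test_case):

    import tensorflow as tf

    def make_graph_inputs(clean_feed_dict, premade_placeholders=False):
        # Only create tf.placeholder nodes when the test did not pre-build its own inputs.
        # Must be called while building a graph; the harness runs this inside a v1 session.
        input_list = []
        if not premade_placeholders:
            for k, v in clean_feed_dict.items():
                input_list.append(
                    tf.compat.v1.placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
        return input_list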

tests/test_backend.py

Lines changed: 17 additions & 19 deletions
@@ -23,7 +23,7 @@
 from common import * # pylint: disable=wildcard-import,unused-wildcard-import
 from tf2onnx import constants, utils
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
-from tf2onnx.tf_loader import is_tf2
+from tf2onnx.tf_loader import is_tf2, tf_placeholder_with_default
 from tf2onnx.onnx_opset.signal import make_dft_constant

 # pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
@@ -711,24 +711,22 @@ def func(x):
             return tf.identity(x, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

-    #@unittest.skip("doesn't work with the new ut func interface, fix later")
-    #def test_placeholder_with_default_use_default(self):
-    #    x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
-    #    def func():
-    #        x = tf.constant(x_val, name="x")
-    #        y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT)
-    #        return tf.identity(y, name=_TFOUTPUT)
-    #    self._run_test_case(func, [_OUTPUT], {})
-
-    #@unittest.skip("doesn't work with the new ut func interface, fix later")
-    #def test_placeholder_with_default_use_feed(self):
-    #    x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
-    #    def func():
-    #        x = tf.constant(x_val, name="x")
-    #        y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT)
-    #        return tf.identity(y, name=_TFOUTPUT)
-    #    x_feed_val = np.array([11.0, 22.0, -33.0, -44.0], dtype=np.float32).reshape((2, 2))
-    #    self._run_test_case(func, [_OUTPUT], {_INPUT: x_feed_val})
+    def test_placeholder_with_default_use_default(self):
+        x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
+        def func():
+            x = tf.constant(x_val, name="x")
+            y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {}, as_session=True, premade_placeholders=True)
+
+    def test_placeholder_with_default_use_feed(self):
+        x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
+        def func():
+            x = tf.constant(x_val, name="x")
+            y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT)
+            return tf.identity(y, name=_TFOUTPUT)
+        x_feed_val = np.array([11.0, 22.0, -33.0, -44.0], dtype=np.float32).reshape((2, 2))
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_feed_val}, as_session=True, premade_placeholders=True)

     @check_onnxruntime_incompatibility("Add")
     def test_add_bcast(self):
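
The two re-enabled tests cover both sides of placeholder_with_default: the node must stay a real graph input (so the converted model can still be fed) while the wired-in constant acts as its default when nothing is fed. A standalone sketch of the TF behavior the tests rely on, using tf.compat.v1 directly instead of the repository's tf_loader aliases (illustrative only, not part of the commit):

    import numpy as np
    import tensorflow as tf

    x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))

    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
        x = tf.constant(x_val, name="x")
        y = tf.compat.v1.placeholder_with_default(x, x_val.shape, name="input")
        out = tf.identity(y, name="output")
        print(sess.run(out))                         # nothing fed -> the default (x_val) flows through
        print(sess.run(out, feed_dict={y: -x_val}))  # fed -> the override replaces the default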

tf2onnx/tf_loader.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,7 @@ def not_implemented_tf_placeholder(*args, **kwargs):
     tf_import_meta_graph = tf.compat.v1.train.import_meta_graph
     tf_gfile = tf.io.gfile
     tf_placeholder = tf.compat.v1.placeholder
+    tf_placeholder_with_default = tf.compat.v1.placeholder_with_default
     extract_sub_graph = tf.compat.v1.graph_util.extract_sub_graph
 elif LooseVersion(tf.__version__) >= "1.13":
     # 1.13 introduced the compat namespace
@@ -83,6 +84,7 @@ def not_implemented_tf_placeholder(*args, **kwargs):
     tf_import_meta_graph = tf.compat.v1.train.import_meta_graph
     tf_gfile = tf.gfile
     tf_placeholder = tf.compat.v1.placeholder
+    tf_placeholder_with_default = tf.compat.v1.placeholder_with_default
     extract_sub_graph = tf.compat.v1.graph_util.extract_sub_graph
 else:
     # older than 1.13
@@ -93,6 +95,7 @@ def not_implemented_tf_placeholder(*args, **kwargs):
     tf_import_meta_graph = tf.train.import_meta_graph
     tf_gfile = tf.gfile
     tf_placeholder = tf.placeholder
+    tf_placeholder_with_default = tf.placeholder_with_default
     extract_sub_graph = tf.graph_util.extract_sub_graph
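
The loader exposes one alias per TF symbol so the rest of the code base never branches on the TF version itself; the three hunks above just add placeholder_with_default to each branch. A condensed, illustrative sketch of that pattern (the real module's guards and symbol set are broader, and the TF 2.x check below is a stand-in for the module's own test):

    from distutils.version import LooseVersion
    import tensorflow as tf

    if LooseVersion(tf.__version__) >= LooseVersion("2.0"):
        tf_placeholder_with_default = tf.compat.v1.placeholder_with_default   # TF 2.x
    elif LooseVersion(tf.__version__) >= "1.13":
        tf_placeholder_with_default = tf.compat.v1.placeholder_with_default   # 1.13 added tf.compat
    else:
        tf_placeholder_with_default = tf.placeholder_with_default             # older than 1.13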

tf2onnx/tf_utils.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def is_huge_shape(x):
                 outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype)
                 outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                 progress = True
-            can_fold = node.type not in ['Enter']
+            can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault']
             can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
             # We can only fold nodes with a single output
             can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
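
The tf_utils.py change protects graph inputs during constant folding: even when a PlaceholderWithDefault's default value is known, folding it into a Const would bake that default in and silently drop the ability to feed the input at runtime (and a plain Placeholder has no value to fold at all). A condensed, hypothetical restatement of the foldability check (helper name invented for illustration; the real logic is inline in tf_utils.py):

    UNFOLDABLE_TYPES = {'Enter', 'Placeholder', 'PlaceholderWithDefault'}

    def can_fold_node(node_type, input_names, output_names, outputs_to_values):
        # Never fold graph inputs or loop-entry nodes, even if their values are known.
        can_fold = node_type not in UNFOLDABLE_TYPES
        # All inputs must already have computed values.
        can_fold = can_fold and len(input_names) > 0 and all(
            inp in outputs_to_values for inp in input_names)
        # Only single-output nodes whose value has not been computed yet are folded.
        can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
        return can_fold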
