|
23 | 23 | from common import * # pylint: disable=wildcard-import,unused-wildcard-import
|
24 | 24 | from tf2onnx import constants, utils
|
25 | 25 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
|
26 |
| -from tf2onnx.tf_loader import is_tf2 |
| 26 | +from tf2onnx.tf_loader import is_tf2, tf_placeholder_with_default |
27 | 27 | from tf2onnx.onnx_opset.signal import make_dft_constant
|
28 | 28 |
|
29 | 29 | # pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
|
@@ -711,24 +711,22 @@ def func(x):
|
711 | 711 | return tf.identity(x, name=_TFOUTPUT)
|
712 | 712 | self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
|
713 | 713 |
|
714 |
| - #@unittest.skip("doesn't work with the new ut func interface, fix later") |
715 |
| - #def test_placeholder_with_default_use_default(self): |
716 |
| - # x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2)) |
717 |
| - # def func(): |
718 |
| - # x = tf.constant(x_val, name="x") |
719 |
| - # y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT) |
720 |
| - # return tf.identity(y, name=_TFOUTPUT) |
721 |
| - # self._run_test_case(func, [_OUTPUT], {}) |
722 |
| - |
723 |
| - #@unittest.skip("doesn't work with the new ut func interface, fix later") |
724 |
| - #def test_placeholder_with_default_use_feed(self): |
725 |
| - # x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2)) |
726 |
| - # def func(): |
727 |
| - # x = tf.constant(x_val, name="x") |
728 |
| - # y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT) |
729 |
| - # return tf.identity(y, name=_TFOUTPUT) |
730 |
| - # x_feed_val = np.array([11.0, 22.0, -33.0, -44.0], dtype=np.float32).reshape((2, 2)) |
731 |
| - # self._run_test_case(func, [_OUTPUT], {_INPUT: x_feed_val}) |
| 714 | + def test_placeholder_with_default_use_default(self): |
| 715 | + x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2)) |
| 716 | + def func(): |
| 717 | + x = tf.constant(x_val, name="x") |
| 718 | + y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT) |
| 719 | + return tf.identity(y, name=_TFOUTPUT) |
| 720 | + self._run_test_case(func, [_OUTPUT], {}, as_session=True, premade_placeholders=True) |
| 721 | + |
| 722 | + def test_placeholder_with_default_use_feed(self): |
| 723 | + x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2)) |
| 724 | + def func(): |
| 725 | + x = tf.constant(x_val, name="x") |
| 726 | + y = tf_placeholder_with_default(x, x_val.shape, name=_TFINPUT) |
| 727 | + return tf.identity(y, name=_TFOUTPUT) |
| 728 | + x_feed_val = np.array([11.0, 22.0, -33.0, -44.0], dtype=np.float32).reshape((2, 2)) |
| 729 | + self._run_test_case(func, [_OUTPUT], {_INPUT: x_feed_val}, as_session=True, premade_placeholders=True) |
732 | 730 |
|
733 | 731 | @check_onnxruntime_incompatibility("Add")
|
734 | 732 | def test_add_bcast(self):
|
|
0 commit comments