Skip to content

Commit 71a3675

Browse files
committed
fix merge
Signed-off-by: xavier dupré <[email protected]>
2 parents 93d9d10 + 04d2488 commit 71a3675

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

tests/backend_test_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,10 @@ def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeh
205205
return result, graph_def, initialized_tables
206206

207207
def convert_to_tfjs(self, graph_def_path, output_names):
208-
from tensorflowjs.converters import converter
208+
try:
209+
from tensorflowjs.converters import converter
210+
except ImportError:
211+
return None
209212
tfjs_path = os.path.join(self.test_data_directory, self._testMethodName + "_tfjs")
210213
try:
211214
converter.convert([graph_def_path, tfjs_path, '--input_format', 'tf_frozen_model',

tests/test_backend.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5559,6 +5559,37 @@ def func1(x):
55595559
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}, optimize=False)
55605560
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})
55615561

5562+
@check_tf_min_version("2.1")
5563+
@skip_tflite("TFlite errors on some attributes")
5564+
@check_opset_min_version(9, "string")
5565+
def test_asstring(self):
5566+
def func(x):
5567+
op_ = tf.strings.as_string(x)
5568+
return tf.identity(op_, name=_TFOUTPUT)
5569+
5570+
x_val = np.array([0, 1, 2, 3], dtype=np.int32)
5571+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5572+
5573+
x_val = np.array([0, 1, 2, 3], dtype=np.float32)
5574+
# can't check the values because in onnx they are padded with 0, in tf they are not
5575+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False)
5576+
5577+
@check_tf_min_version("2.1")
5578+
@skip_tflite("TFlite errors on some attributes")
5579+
@check_opset_min_version(9, "string")
5580+
def test_string_to_number(self):
5581+
def func(x):
5582+
op_ = tf.strings.to_number(x)
5583+
return tf.identity(op_, name=_TFOUTPUT)
5584+
5585+
# tf gets this wrong and returns fp32 instead of int
5586+
x_val = np.array("123", dtype=np.object)
5587+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5588+
5589+
x_val = np.array("123.1", dtype=np.object)
5590+
# can't check the values because in onnx they are padded with 0, in tf they are not
5591+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False)
5592+
55625593

55635594
if __name__ == '__main__':
55645595
cl = BackendTests()

tf2onnx/onnx_opset/tensor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3799,3 +3799,27 @@ def version_8(cls, ctx, node, **kwargs):
37993799
# broadcast by expanding
38003800
node.type = "Expand"
38013801
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
3802+
3803+
3804+
@tf_op("AsString")
3805+
class AsString:
3806+
@classmethod
3807+
def version_9(cls, ctx, node, **kwargs):
3808+
if (node.get_attr_value("precision") or node.get_attr_value("scientific") or node.get_attr_value("fill")):
3809+
logger.warning(
3810+
"ONNX does not support precision, scientific and fill attributes for AsString")
3811+
shapes = node.output_shapes
3812+
ctx.remove_node(node.name)
3813+
_ = ctx.make_node("Cast", node.input, name=node.name,
3814+
dtypes=[TensorProto.STRING], shapes=shapes, attr={"to": TensorProto.STRING})
3815+
3816+
3817+
@tf_op("StringToNumber")
3818+
class StringToNumber:
3819+
@classmethod
3820+
def version_9(cls, ctx, node, **kwargs):
3821+
shapes = node.output_shapes
3822+
dtypes = node.output_dtypes
3823+
ctx.remove_node(node.name)
3824+
_ = ctx.make_node("Cast", node.input, name=node.name,
3825+
dtypes=dtypes, shapes=shapes, attr={"to": dtypes[0]})

0 commit comments

Comments
 (0)