Skip to content

Commit 8c081ef

Browse files
Implement fused tfjs ops (#1620)
* Squashed commit of the following: commit de562af Author: Tom Wildenhain <[email protected]> Date: Tue Jul 13 18:33:56 2021 -0700 Pylint commit 39dc6d2 Author: Tom Wildenhain <[email protected]> Date: Tue Jul 13 18:25:38 2021 -0700 Rename tfjs_runner.py Signed-off-by: Tom Wildenhain <[email protected]> commit cdcccfc Author: Tom Wildenhain <[email protected]> Date: Tue Jul 13 18:13:08 2021 -0700 Enable tfjs unit tests Signed-off-by: Tom Wildenhain <[email protected]> commit b6efae4 Merge: de5ac2c c5aba9a Author: Tom Wildenhain <[email protected]> Date: Mon Jul 12 22:00:14 2021 -0700 Merge branch 'master' into tom/tfjs2 commit de5ac2c Author: Tom Wildenhain <[email protected]> Date: Mon Jul 12 21:59:17 2021 -0700 Exclude tfjs tests with broken tf->tfjs conversions Signed-off-by: Tom Wildenhain <[email protected]> commit d92d7a3 Author: Tom Wildenhain <[email protected]> Date: Mon Jul 12 13:55:44 2021 -0700 TFJS fixes Signed-off-by: Tom Wildenhain <[email protected]> commit 103761f Merge: e4889a7 b65b05e Author: Tom Wildenhain <[email protected]> Date: Fri Jul 9 13:34:39 2021 -0700 Merge branch 'master' into tom/tfjs2 commit e4889a7 Merge: a0c2ee4 664430c Author: Tom Wildenhain <[email protected]> Date: Thu Jul 8 21:17:29 2021 -0700 Merge branch 'master' into tom/tfjs2 commit a0c2ee4 Author: Tom Wildenhain <[email protected]> Date: Thu Jul 8 21:17:09 2021 -0700 Improve testing for tfjs Signed-off-by: Tom Wildenhain <[email protected]> commit ca0f9bd Author: Tom Wildenhain <[email protected]> Date: Thu Jul 8 20:41:26 2021 -0700 Add tests for tfjs Signed-off-by: Tom Wildenhain <[email protected]> commit 91ffbca Author: Tom Wildenhain <[email protected]> Date: Wed Jul 7 20:50:15 2021 -0700 Implement tfjs conversion Signed-off-by: Tom Wildenhain <[email protected]> commit 5190778 Merge: e04b40f 448d61d Author: Tom Wildenhain <[email protected]> Date: Wed Jul 7 14:44:08 2021 -0700 Merge branch 'master' into tom/tfjs2 commit e04b40f Merge: 58ed277 dc1d0e1 Author: Tom Wildenhain <[email protected]> Date: Tue Jul 6 20:40:21 2021 -0700 Merge branch 'tom/refactor_tf2onnx2' into tom/tfjs2 commit dc1d0e1 Author: Tom Wildenhain <[email protected]> Date: Tue Jul 6 19:44:20 2021 -0700 Refactor tflite/tf logic out of tfonnx into tf_utils and tflite_utils Signed-off-by: Tom Wildenhain <[email protected]> commit 58ed277 Merge: b4c3186 552b0a2 Author: Tom Wildenhain <[email protected]> Date: Tue Jul 6 18:56:02 2021 -0700 Merge branch 'master' into tom/tfjs2 commit b4c3186 Author: Tom Wildenhain <[email protected]> Date: Fri Jul 2 18:31:49 2021 -0700 WIP Signed-off-by: Tom Wildenhain <[email protected]> * reorganize Signed-off-by: Tom Wildenhain <[email protected]> * pylint Signed-off-by: Tom Wildenhain <[email protected]> * Improve documentation Signed-off-by: Tom Wildenhain <[email protected]> * Pylint Signed-off-by: Tom Wildenhain <[email protected]> * Implement conversion of _Fused ops seen in tfjs Signed-off-by: Tom Wildenhain <[email protected]> * Enable tests with fused ops Signed-off-by: Tom Wildenhain <[email protected]>
1 parent fc91639 commit 8c081ef

File tree

4 files changed

+38
-3
lines changed

4 files changed

+38
-3
lines changed

tests/test_backend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,7 +2770,6 @@ def func(x, mean, offset, var):
27702770
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: mean_val, _INPUT2: offset_val, _INPUT3: var_val})
27712771

27722772
@check_opset_min_version(7, "batchnorm")
2773-
@skip_tfjs("Unsupported _FusedConv2D op") # TODO: implement this
27742773
def test_conv2d_batchnorm_fusion(self):
27752774
x_shape = [1, 28, 28, 2]
27762775
x_val = np.random.random_sample(x_shape).astype(np.float32)
@@ -4272,7 +4271,6 @@ def func(a, b, c):
42724271
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
42734272

42744273
# test for gemm pattern4: A*B + C [addbias] - 1D bias!
4275-
@skip_tfjs("Unsupported _FusedMatMul op") # TODO: implement this
42764274
def test_gemm_pattern4(self):
42774275
max_number = 10
42784276
m = np.random.randint(max_number)

tf2onnx/rewriter/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tf2onnx.rewriter.layer_normalization_rewriter import rewrite_layer_normalization
2323
from tf2onnx.rewriter.ragged_variant_shape_rewriter import rewrite_ragged_variant_shape
2424
from tf2onnx.rewriter.lstm_tf2_rewriter import rewriter_lstm_tf2
25+
from tf2onnx.rewriter.fused_op_rewriter import rewrite_fused_ops
2526

2627

2728
__all__ = [
@@ -48,5 +49,6 @@
4849
"rewrite_layer_normalization",
4950
"rewrite_conv_dilations",
5051
"rewrite_ragged_variant_shape",
51-
"rewriter_lstm_tf2"
52+
"rewriter_lstm_tf2",
53+
"rewrite_fused_ops",
5254
]

tf2onnx/rewriter/fused_op_rewriter.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tf2onnx.rewriter.fused_op_rewriter - rewrite tensorflow _Fused ops from grappler into other tf ops
6+
"""
7+
8+
9+
# pylint: disable=missing-docstring
10+
11+
12+
def rewrite_fused_ops(g, ops):
13+
for node in ops:
14+
if node.type in ["_FusedConv2D", "_FusedMatMul"]:
15+
op_types = [op.decode() for op in node.get_attr_value("fused_ops")]
16+
extra_inputs = node.input[2:]
17+
g.replace_inputs(node, node.input[:2])
18+
last_output = node.output[0]
19+
node.type = node.type.replace("_Fused", "")
20+
dtype = g.get_dtype(node.output[0])
21+
shape = g.get_shape(node.output[0])
22+
first_node = None
23+
for op in op_types:
24+
new_node = g.make_node(op, [last_output] + extra_inputs, skip_conversion=False,
25+
op_name_scope=node.name, dtypes=[dtype], shapes=[shape])
26+
last_output = new_node.output[0]
27+
if first_node is None:
28+
first_node = new_node
29+
extra_inputs = []
30+
31+
consumers = [n for n in g.find_output_consumers(node.output[0]) if n != first_node]
32+
g.replace_all_inputs(node.output[0], last_output, consumers)
33+
34+
return g.get_nodes()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ def compat_handler(ctx, node, **kwargs):
554554
# single directional
555555
rewrite_constant_fold,
556556
rewrite_quantize_and_dequantize,
557+
rewrite_fused_ops,
557558
rewrite_transpose,
558559
rewrite_flatten,
559560
rewrite_random_uniform,

0 commit comments

Comments
 (0)