
Commit b540e3c

Added rewriter for specific pattern in a 1P model (#1380)
* Added rewriter for specific pattern in a 1P model

Signed-off-by: Tom Wildenhain <[email protected]>

* Improved documentation of rewriter

Signed-off-by: Tom Wildenhain <[email protected]>
1 parent acd34b1 commit b540e3c

5 files changed: 152 additions & 3 deletions


tests/test_backend.py

Lines changed: 17 additions & 0 deletions
@@ -226,6 +226,23 @@ def func(x, y):
             return tf.identity(op, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})

+    @check_opset_min_version(9, "ConstantOfShape")
+    def test_layer_normalization(self):
+        x_val = make_xval([3, 4, 5])
+        scale_val = make_xval([3, 4, 5]) * 0.2
+        bias_val = make_xval([3, 4, 5]) * 0.1
+        def func(x):
+            mean = tf.reduce_mean(x, axis=[2], keepdims=True)
+            centered = tf.subtract(x, mean)
+            variance = tf.add(tf.reduce_mean(tf.square(centered), axis=[2], keepdims=True), 0.001)
+            inv_std_dev = tf.math.rsqrt(variance)
+            normalized = tf.multiply(centered, inv_std_dev)
+            scaled = tf.multiply(normalized, scale_val)
+            biased = tf.add(scaled, bias_val)
+            return tf.identity(biased, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
+                            graph_validator=lambda g: (check_op_count(g, "InstanceNormalization", 1)))
+
     @check_opset_min_version(9, "ConstantOfShape")
     def test_eye_non_const1(self):
         # tf.eye(num_rows), num_rows is not const here

tf2onnx/graph_matcher.py

Lines changed: 8 additions & 2 deletions
@@ -28,7 +28,7 @@
 class OpTypePattern(object):
     """A tree pattern that matches TF expressions with certain op types."""

-    def __init__(self, op_type, name=None, inputs=None):
+    def __init__(self, op_type, name=None, inputs=None, allow_reorder=None):
         """Initializes an OpTypePattern.

         Args:
@@ -43,9 +43,12 @@ def __init__(self, op_type, name=None, inputs=None):
           inputs: Optional list of `OpTypePattern`s or strings that specify the
             patterns for the inputs of a matching op. If None, this pattern accepts
             any inputs of a matching op.
+          allow_reorder: Optional boolean that overrides allow_reorder in GraphMatcher
+            for this pattern's immediate inputs.
         """
         self._op_type = op_type
         self._name = name
+        self.allow_reorder = allow_reorder
         if inputs is None:
             inputs = []
         self._inputs = [
@@ -202,7 +205,10 @@ def _match_pattern(self, pattern, op, tensor):
         if not op or len(op.inputs) != len(pattern.inputs):
             return False, match_list

-        if self._allow_reorder:
+        allow_reorder = pattern.allow_reorder
+        if allow_reorder is None:
+            allow_reorder = self._allow_reorder
+        if allow_reorder:
             pattern_inputs_list = permutations(pattern.inputs)
         else:
             pattern_inputs_list = [pattern.inputs]

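The practical effect of the new per-pattern flag: GraphMatcher(pattern, allow_reorder=True) still lets commutative inputs match in either order, while allow_reorder=False on an individual OpTypePattern pins the input order for that one node. A minimal sketch (the pattern below is illustrative only, not code from this commit), mirroring how the layer-normalization rewriter further down uses the override:

# Hypothetical pattern: Mul is commutative, so its inputs may be reordered by the matcher,
# but Sub is not, so its operand order is pinned with allow_reorder=False.
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher

pattern = \
    OpTypePattern('Mul', inputs=[
        OpTypePattern('*', name='scale'),
        OpTypePattern('Sub', allow_reorder=False, inputs=[
            OpTypePattern('*', name='x'),
            OpTypePattern('Mean', name='mean')
        ])
    ])

matcher = GraphMatcher(pattern, allow_reorder=True)
# match_results = list(matcher.match_ops(ops))  # ops: the graph's node list, as in the rewriter below
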
tf2onnx/rewriter/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
 from tf2onnx.rewriter.transpose_rewriter import rewrite_transpose
 from tf2onnx.rewriter.conv2d_with_add_rewriter import rewrite_biasadd_with_conv2d
 from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize
+from tf2onnx.rewriter.layer_normalization_rewriter import rewrite_layer_normalization


 __all__ = [
@@ -44,5 +45,6 @@
     "rewrite_custom_rnn_cell",
     "rewrite_generic_loop",
     "rewrite_biasadd_with_conv2d",
-    "rewrite_quantize_and_dequantize"
+    "rewrite_quantize_and_dequantize",
+    "rewrite_layer_normalization"
 ]
tf2onnx/rewriter/layer_normalization_rewriter.py (new file)

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+tf2onnx.rewrite - Rewrites a pattern from the tf layer_norm contrib op.
+Converts a mean/variance normalization pattern (using ReduceMean, RSqrt, Sub, Mul, etc.) into InstanceNormalization
+"""
+from onnx import TensorProto, helper
+from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+from tf2onnx.graph_builder import GraphBuilder
+
+
+# pylint: disable=missing-docstring
+
+def rewrite_layer_normalization(g, ops):
+    # Needs ConstantOfShape
+    if g.opset <= 9:
+        return ops
+
+    inner_pattern = \
+        OpTypePattern('Rsqrt', inputs=[
+            OpTypePattern('Add', inputs=[
+                OpTypePattern('Mean', allow_reorder=False, inputs=[
+                    OpTypePattern('Square', inputs=[
+                        OpTypePattern('Sub', allow_reorder=False, inputs=[
+                            OpTypePattern('*', name='input'),
+                            OpTypePattern('Mean', name='mean', allow_reorder=False, inputs=[
+                                OpTypePattern('*', name='input_r2'),
+                                OpTypePattern('Const|ConstV2', name='mean_axes')
+                            ])
+                        ])
+                    ]),
+                    OpTypePattern('Const|ConstV2', name='variance_axes')
+                ]),
+                OpTypePattern('Const|ConstV2', name='epsilon')
+            ])
+        ])
+
+    pattern0 = \
+        OpTypePattern('Add', name='bias_add', inputs=[
+            OpTypePattern('Mul', name='scale_mul', inputs=[
+                OpTypePattern('Mul', inputs=[
+                    inner_pattern,
+                    OpTypePattern('*', name='scale')
+                ]),
+                OpTypePattern('Sub', inputs=[
+                    OpTypePattern('*', name='input_r3'),
+                    OpTypePattern('Mean', name='mean_r2')
+                ])
+            ]),
+            OpTypePattern('*', name='bias')
+        ])
+    pattern1 = \
+        OpTypePattern('Add', name='bias_add', inputs=[
+            OpTypePattern('Mul', name='scale_mul', inputs=[
+                OpTypePattern('Mul', inputs=[
+                    inner_pattern,
+                    OpTypePattern('Sub', inputs=[
+                        OpTypePattern('*', name='input_r3'),
+                        OpTypePattern('Mean', name='mean_r2')
+                    ])
+                ]),
+                OpTypePattern('*', name='scale')
+            ]),
+            OpTypePattern('*', name='bias'),
+        ])
+    pattern2 = \
+        OpTypePattern('Add', name='bias_add', inputs=[
+            OpTypePattern('Mul', name='scale_mul', inputs=[
+                OpTypePattern('Mul', inputs=[
+                    OpTypePattern('*', name='scale'),
+                    OpTypePattern('Sub', inputs=[
+                        OpTypePattern('*', name='input_r3'),
+                        OpTypePattern('Mean', name='mean_r2')
+                    ])
+                ]),
+                inner_pattern
+            ]),
+            OpTypePattern('*', name='bias'),
+        ])
+
+    pattern_list = [pattern0, pattern1, pattern2]
+
+    for pattern in pattern_list:
+        matcher = GraphMatcher(pattern, allow_reorder=True)
+        match_results = list(matcher.match_ops(ops))
+        if match_results:
+            for match in match_results:
+                inp_node = match.get_op('input')
+                rank = g.get_rank(inp_node.output[0])
+                node = match.get_op('bias_add')
+                if inp_node.name != match.get_op('input_r2').name or inp_node.name != match.get_op('input_r3').name:
+                    continue
+                if match.get_op('mean').name != match.get_op('mean_r2').name:
+                    continue
+                inp = match.get_op('mean').input[0]
+                if rank != 3:
+                    continue
+                mean_axes = match.get_op('mean_axes').get_tensor_value(as_list=True)
+                variance_axes = match.get_op('variance_axes').get_tensor_value(as_list=True)
+                mean_axes = [a % rank for a in mean_axes]
+                variance_axes = [a % rank for a in variance_axes]
+                if mean_axes != [2] or variance_axes != [2]:
+                    continue
+                epsilon = match.get_op('epsilon').get_tensor_value(as_list=False).flatten().tolist()
+                if len(epsilon) != 1:
+                    continue
+                scale = match.get_op('scale').output[0]
+                bias = match.get_op('bias').output[0]
+                shape = g.make_node("Shape", [inp]).output[0]
+                dim_2_shape = GraphBuilder(g).make_slice(
+                    {"data": shape, "ends": [2], "starts": [1], "axes": [0]})
+                zero_tensor = helper.make_tensor("value", TensorProto.FLOAT, dims=[1], vals=[0])
+                one_tensor = helper.make_tensor("value", TensorProto.FLOAT, dims=[1], vals=[1])
+                zeros_of_shape = g.make_node("ConstantOfShape", [dim_2_shape], attr={'value': zero_tensor}).output[0]
+                ones_of_shape = g.make_node("ConstantOfShape", [dim_2_shape], attr={'value': one_tensor}).output[0]
+                norm = g.make_node("InstanceNormalization", [inp, ones_of_shape, zeros_of_shape],
+                                   attr={'epsilon': epsilon[0]}, op_name_scope=node.name).output[0]
+                mul = g.make_node("Mul", [norm, scale]).output[0]
+                add = g.make_node("Add", [mul, bias]).output[0]
+                g.replace_all_inputs(node.output[0], add)
+                g.remove_node(node.name)
+    return ops

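Why the rewrite is valid in the guarded case (rank-3 input, mean and variance over axis 2): ONNX InstanceNormalization normalizes each (n, c) slice of an [N, C, W] tensor over W, so with a per-channel scale of ones and bias of zeros it reproduces the (x - mean) * rsqrt(var + epsilon) core of the matched pattern, and the trailing Mul and Add reapply the original scale and bias. A small numpy check of that equivalence (illustrative sketch, not part of the commit):

# Hedged sketch: the hand-written TF pattern over the last axis of a rank-3 tensor
# equals InstanceNormalization (unit scale, zero bias) followed by Mul(scale) and Add(bias).
import numpy as np

def handwritten_layer_norm(x, scale, bias, eps=0.001):
    mean = x.mean(axis=2, keepdims=True)
    centered = x - mean
    variance = (centered ** 2).mean(axis=2, keepdims=True) + eps
    return centered * (1.0 / np.sqrt(variance)) * scale + bias

def instance_norm_reference(x, eps=0.001):
    # What InstanceNormalization computes on an [N, C, W] tensor with scale=1, bias=0:
    # per-(n, c) normalization over the W axis (biased variance, like np.var with ddof=0).
    mean = x.mean(axis=2, keepdims=True)
    var = x.var(axis=2, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.rand(3, 4, 5)
scale = np.random.rand(3, 4, 5)
bias = np.random.rand(3, 4, 5)
rewritten = instance_norm_reference(x) * scale + bias
assert np.allclose(handwritten_layer_norm(x, scale, bias), rewritten, rtol=1e-5)
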
tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
@@ -616,6 +616,7 @@ def compat_handler(ctx, node, **kwargs):
         rewrite_custom_rnn_cell,
         rewrite_generic_loop, rewrite_cond,
         rewrite_biasadd_with_conv2d,
+        rewrite_layer_normalization,
         rewrite_gemm,
     ]


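Each entry in this list is a rewriter with the same (g, ops) -> ops contract as rewrite_layer_normalization above, and the conversion pass applies them in order. Roughly like the sketch below (a simplification for orientation, not the actual driver in tfonnx.py; the get_nodes/reset_nodes usage here is assumed):

# Simplified, hypothetical sketch of how a rewriter list with the (g, ops) -> ops contract is driven.
def run_rewriters_sketch(g, rewriters):
    ops = g.get_nodes()
    for rewrite in rewriters:
        ops = rewrite(g, ops)     # each rewriter returns the (possibly rewritten) op list
        if ops is not None:
            g.reset_nodes(ops)    # assumed Graph API call; shown for illustration only
    return g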