Skip to content

Commit dbd3569

Browse files
JiayingGaoowayuanho
authored andcommitted
add gemm_rewriter and the corresponding test (#597)
* add gemm_rewriter and the corresponding tests * add safe_remove_nodes function in gemm_rewriter.py * Coding improvements
1 parent f593e7d commit dbd3569

File tree

5 files changed

+199
-7
lines changed

5 files changed

+199
-7
lines changed

tests/test_backend.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,6 +2571,80 @@ def test_selu(self):
25712571
_ = tf.identity(y, name=_TFOUTPUT)
25722572
self._run_test_case([_OUTPUT], {_INPUT: x_val})
25732573

2574+
# test for gemm pattern0: alpha*A*B + beta*C
2575+
def test_gemm_pattern0(self):
2576+
max_number = 10
2577+
m = np.random.randint(max_number)
2578+
n = np.random.randint(max_number)
2579+
k = np.random.randint(max_number)
2580+
x_val1 = np.random.rand(m, n).astype("float32")
2581+
x_val2 = np.random.rand(n, k).astype("float32")
2582+
x_val3 = np.random.rand(m, k).astype("float32")
2583+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2584+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2585+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2586+
alpha = tf.constant(1.0, dtype=tf.float32)
2587+
beta = tf.constant(2.0, dtype=tf.float32)
2588+
mul1 = tf.multiply(alpha, tf.matmul(a, b))
2589+
mul2 = tf.multiply(beta, c)
2590+
x_ = mul1 + mul2
2591+
_ = tf.identity(x_, name=_TFOUTPUT)
2592+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2593+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2594+
2595+
# test for gemm pattern1: alpha*A*B + C
2596+
def test_gemm_pattern1(self):
2597+
max_number = 10
2598+
m = np.random.randint(max_number)
2599+
n = np.random.randint(max_number)
2600+
k = np.random.randint(max_number)
2601+
x_val1 = np.random.rand(m, n).astype("float32")
2602+
x_val2 = np.random.rand(n, k).astype("float32")
2603+
x_val3 = np.random.rand(m, k).astype("float32")
2604+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2605+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2606+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2607+
alpha = tf.constant(1.0, dtype=tf.float32)
2608+
x_ = tf.multiply(alpha, tf.matmul(a, b)) + c
2609+
_ = tf.identity(x_, name=_TFOUTPUT)
2610+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2611+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2612+
2613+
# test for gemm pattern2: A*B + beta*C
2614+
def test_gemm_pattern2(self):
2615+
max_number = 10
2616+
m = np.random.randint(max_number)
2617+
n = np.random.randint(max_number)
2618+
k = np.random.randint(max_number)
2619+
x_val1 = np.random.rand(m, n).astype("float32")
2620+
x_val2 = np.random.rand(n, k).astype("float32")
2621+
x_val3 = np.random.rand(m, k).astype("float32")
2622+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2623+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2624+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2625+
beta = tf.constant(2.0, dtype=tf.float32)
2626+
x_ = tf.matmul(a, b) + tf.multiply(beta, c)
2627+
_ = tf.identity(x_, name=_TFOUTPUT)
2628+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2629+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2630+
2631+
# test for gemm pattern3: A*B + C
2632+
def test_gemm_pattern3(self):
2633+
max_number = 10
2634+
m = np.random.randint(max_number)
2635+
n = np.random.randint(max_number)
2636+
k = np.random.randint(max_number)
2637+
x_val1 = np.random.rand(m, n).astype("float32")
2638+
x_val2 = np.random.rand(n, k).astype("float32")
2639+
x_val3 = np.random.rand(m, k).astype("float32")
2640+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2641+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2642+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2643+
x_ = tf.matmul(a, b) + c
2644+
_ = tf.identity(x_, name=_TFOUTPUT)
2645+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2646+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2647+
25742648
def test_graph_matcher(self):
25752649
shape = [2, 6]
25762650
x_val = np.random.random(shape).astype(np.float32)

tf2onnx/graph_matcher.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,30 @@ def _to_pattern(self, pattern_or_name):
107107
return pattern_or_name
108108

109109
if isinstance(pattern_or_name, six.text_type):
110-
return self._name_to_pattern[pattern_or_name]
110+
return self._name_to_pattern.get(pattern_or_name)
111111

112112
raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.'
113113
% type(pattern_or_name))
114114

115-
def get_op(self, pattern_or_name):
116-
return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0]
115+
def get_op(self, pattern_or_name, default=None):
116+
"""
117+
For now, if the op can not be effectively obtained, then the function will return the default
118+
instead of an error.
119+
"""
120+
op_and_tensor = self._pattern_to_op_tensor.get(self._to_pattern(pattern_or_name))
121+
if op_and_tensor:
122+
return op_and_tensor[0]
123+
return default
117124

118-
def get_tensor(self, pattern_or_name):
119-
return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1]
125+
def get_tensor(self, pattern_or_name, default=None):
126+
"""
127+
For now, if the tensor can not be effectively obtained, then the function will return the default
128+
instead of an error.
129+
"""
130+
op_and_tensor = self._pattern_to_op_tensor.get(self._to_pattern(pattern_or_name))
131+
if op_and_tensor:
132+
return op_and_tensor[1]
133+
return default
120134

121135
def get_nodes(self):
122136
return [n[0] for n in self._pattern_to_op_tensor.values()]

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12+
from tf2onnx.rewriter.gemm_rewriter import rewrite_gemm
1213
from tf2onnx.rewriter.eye_rewriter import rewrite_eye
1314
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
1415
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
@@ -27,5 +28,6 @@
2728
"rewrite_single_direction_gru",
2829
"rewrite_bi_direction_gru",
2930
"rewrite_custom_rnn_cell",
31+
"rewrite_gemm",
3032
"rewrite_generic_loop"
3133
]

tf2onnx/rewriter/gemm_rewriter.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx gemm op
6+
"""
7+
import logging
8+
from onnx import onnx_pb
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
# pylint: disable=missing-docstring
12+
13+
def rewrite_gemm(g, ops):
14+
if g.opset <= 6:
15+
return ops
16+
17+
# pattern0: alpha*A*B + beta*C
18+
pattern0 = \
19+
OpTypePattern('Add', name='add', inputs=[
20+
OpTypePattern('Mul', name='mul1', inputs=[
21+
OpTypePattern('Const', name='alpha'),
22+
OpTypePattern('MatMul', name='matmul')
23+
]),
24+
OpTypePattern('Mul', name='mul2', inputs=[
25+
OpTypePattern('Const', name='beta'),
26+
OpTypePattern('*', name='C')
27+
])
28+
])
29+
30+
# pattern1: alpha*A*B + C
31+
pattern1 = \
32+
OpTypePattern('Add', name='add', inputs=[
33+
OpTypePattern('Mul', name='mul1', inputs=[
34+
OpTypePattern('MatMul', name='matmul'),
35+
OpTypePattern('Const', name='alpha')
36+
]),
37+
OpTypePattern('*', name='C'),
38+
])
39+
40+
# pattern2: A*B + beta*C
41+
pattern2 = \
42+
OpTypePattern('Add', name='add', inputs=[
43+
OpTypePattern('MatMul', name='matmul'),
44+
OpTypePattern('Mul', name='mul2', inputs=[
45+
OpTypePattern('Const', name='beta'),
46+
OpTypePattern('*', name='C')
47+
])
48+
])
49+
50+
# pattern3: A*B + C
51+
pattern3 = \
52+
OpTypePattern('Add', name='add', inputs=[
53+
OpTypePattern('MatMul', name='matmul'),
54+
OpTypePattern('*', name='C'),
55+
])
56+
57+
pattern_list = [pattern0, pattern1, pattern2, pattern3]
58+
59+
for pattern in pattern_list:
60+
matcher = GraphMatcher(pattern, allow_reorder=True)
61+
match_results = list(matcher.match_ops(ops))
62+
if match_results:
63+
for match in match_results:
64+
matmul_node = match.get_op("matmul")
65+
66+
if g.get_dtype(matmul_node.input[0]) != onnx_pb.TensorProto.FLOAT:
67+
logging.warning(u"For now, onnxruntime only support float32 type for Gemm rewriter")
68+
continue
69+
70+
attr, is_valid = get_gemm_attr(match)
71+
if not is_valid:
72+
continue
73+
74+
add_node = match.get_op('add')
75+
input_c_node = match.get_op("C")
76+
a_edge_name = matmul_node.input[0]
77+
b_edge_name = matmul_node.input[1]
78+
c_edge_name = input_c_node.output[0]
79+
80+
gemm = g.make_node("Gemm", inputs=[a_edge_name, b_edge_name, c_edge_name],
81+
attr=attr,
82+
shapes=[g.get_shape(add_node.output[0])],
83+
dtypes=[g.get_dtype(add_node.output[0])])
84+
85+
ops.append(gemm)
86+
g.replace_all_inputs(ops, add_node.output[0], gemm.output[0])
87+
to_delete = [add_node, matmul_node]
88+
g.safe_remove_nodes(to_delete)
89+
return ops
90+
91+
def get_gemm_attr(match):
92+
attr = {}
93+
for arg in ["alpha", "beta"]:
94+
arg_op = match.get_op(arg)
95+
if arg_op is not None:
96+
match_args = arg_op.get_tensor_value()
97+
if isinstance(match_args, list):
98+
if len(match_args) != 1:
99+
return attr, False
100+
match_args = match_args[0]
101+
attr[arg] = match_args
102+
return attr, True

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,13 +768,13 @@ def compat_handler(ctx, node, **kwargs):
768768

769769
# pre-processing graph rewrites
770770
# bi-directional re-writer should be placed after single directional re-writer
771-
rewriters = [rewrite_transpose, rewrite_flatten,
771+
rewriters = [rewrite_transpose, rewrite_flatten, rewrite_gemm,
772772
rewrite_random_uniform, rewrite_random_uniform_fold_const,
773773
rewrite_random_normal, rewrite_dropout, rewrite_eye,
774774
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
775775
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
776776
rewrite_single_direction_gru, rewrite_bi_direction_gru,
777-
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond
777+
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
778778
]
779779

780780
if custom_rewriter is not None:

0 commit comments

Comments
 (0)