Skip to content

Commit 1ce57d8

Browse files
committed
initial fix - test passes always to do to unfixed FIXME in check_op_count
1 parent 1f606c3 commit 1ce57d8

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,6 +2994,21 @@ def test_gemm_pattern3(self):
29942994
def func(a, b, c):
29952995
x_ = tf.matmul(a, b) + c
29962996
return tf.identity(x_, name=_TFOUTPUT)
2997+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2998+
graph_validator=lambda g: check_op_count(g, "Gemm", 1) )
2999+
3000+
# test for gemm pattern4: A*B + C [addbias] - 1D bias!
3001+
def test_gemm_pattern4(self):
3002+
max_number = 10
3003+
m = np.random.randint(max_number)
3004+
n = np.random.randint(max_number)
3005+
k = np.random.randint(max_number) # bias add requires 1D tensor
3006+
x_val1 = np.random.rand(m, n).astype("float32")
3007+
x_val2 = np.random.rand(n, k).astype("float32")
3008+
x_val3 = np.random.rand(k).astype("float32")
3009+
def func(a, b, c):
3010+
x_ = tf.nn.bias_add(tf.matmul(a, b), c)
3011+
return tf.identity(x_, name=_TFOUTPUT)
29973012
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
29983013
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
29993014

tf2onnx/rewriter/gemm_rewriter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@ def rewrite_gemm(g, ops):
5555
OpTypePattern('*', name='C'),
5656
])
5757

58-
pattern_list = [pattern0, pattern1, pattern2, pattern3]
58+
# pattern4: A*B + c
59+
pattern4 = \
60+
OpTypePattern('BiasAdd', name='add', inputs=[
61+
OpTypePattern('MatMul', name='matmul'),
62+
OpTypePattern('*', name='C'),
63+
])
64+
65+
pattern_list = [pattern0, pattern1, pattern2, pattern3, pattern4]
5966

6067
for pattern in pattern_list:
6168
matcher = GraphMatcher(pattern, allow_reorder=True)

0 commit comments

Comments
 (0)