Skip to content

Commit 1a35937

Browse files
Merge pull request #1069 from phager90/fix/gemmRewriter_BiasAdd
Fix/gemm rewriter bias add
2 parents ad7fe46 + b97bed0 commit 1a35937

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2997,6 +2997,21 @@ def func(a, b, c):
29972997
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
29982998
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
29992999

3000+
# test for gemm pattern4: A*B + C [addbias] - 1D bias!
3001+
def test_gemm_pattern4(self):
3002+
max_number = 10
3003+
m = np.random.randint(max_number)
3004+
n = np.random.randint(max_number)
3005+
k = np.random.randint(max_number) # bias add requires 1D tensor
3006+
x_val1 = np.random.rand(m, n).astype("float32")
3007+
x_val2 = np.random.rand(n, k).astype("float32")
3008+
x_val3 = np.random.rand(k).astype("float32")
3009+
def func(a, b, c):
3010+
x_ = tf.nn.bias_add(tf.matmul(a, b), c)
3011+
return tf.identity(x_, name=_TFOUTPUT)
3012+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
3013+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
3014+
30003015
# test for gemm pattern0: alpha*A*B + beta*C
30013016
@check_opset_min_version(12, "Optimizer bug in ORT 1.2")
30023017
def test_gemm_pattern0_fail_broadcast(self):

tf2onnx/rewriter/gemm_rewriter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@ def rewrite_gemm(g, ops):
5555
OpTypePattern('*', name='C'),
5656
])
5757

58-
pattern_list = [pattern0, pattern1, pattern2, pattern3]
58+
# pattern4: A*B + c
59+
pattern4 = \
60+
OpTypePattern('BiasAdd', name='add', inputs=[
61+
OpTypePattern('MatMul', name='matmul'),
62+
OpTypePattern('*', name='C'),
63+
])
64+
65+
pattern_list = [pattern0, pattern1, pattern2, pattern3, pattern4]
5966

6067
for pattern in pattern_list:
6168
matcher = GraphMatcher(pattern, allow_reorder=True)

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,13 +451,13 @@ def compat_handler(ctx, node, **kwargs):
451451
# pre-processing graph rewrites
452452
# bi-directional re-writer should be placed after single directional re-writer
453453
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,
454-
rewrite_gemm, rewrite_random_uniform, rewrite_random_uniform_fold_const,
454+
rewrite_random_uniform, rewrite_random_uniform_fold_const,
455455
rewrite_random_normal, rewrite_dropout, rewrite_eye,
456456
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
457457
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
458458
rewrite_single_direction_gru, rewrite_bi_direction_gru,
459459
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
460-
rewrite_biasadd_with_conv2d,
460+
rewrite_biasadd_with_conv2d, rewrite_gemm
461461
]
462462

463463
if custom_rewriter is not None:

0 commit comments

Comments
 (0)