Skip to content

Commit 59fed17

Browse files
authored
Merge pull request #906 from jignparm/jignparm/gemm_broadcast
Fix GEMM to check for shape broadcast compatibility of A*B and C
2 parents 7c37ccb + cffe8c5 commit 59fed17

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2859,6 +2859,31 @@ def func(a, b, c):
28592859
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
28602860
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
28612861

2862+
# test for gemm pattern0: alpha*A*B + beta*C
2863+
@check_opset_min_version(12, "Optimizer bug in ORT 1.2")
2864+
def test_gemm_pattern0_fail_broadcast(self):
2865+
# shapes (3, 3) * (3, 1) + (1, 4) => (3, 1) + (1, 4)
2866+
# c not uni-broadcastable to a * b, so should not use GEMM
2867+
m, n, k = 3, 3, 1
2868+
x_val1 = np.random.rand(m, n).astype("float32")
2869+
x_val2 = np.random.rand(n, k).astype("float32")
2870+
x_val3 = np.random.rand(k, 4).astype("float32")
2871+
2872+
def func(a, b, c):
2873+
alpha = tf.constant(1.0, dtype=tf.float32)
2874+
beta = tf.constant(2.0, dtype=tf.float32)
2875+
mul1 = tf.multiply(alpha, tf.matmul(a, b))
2876+
mul2 = tf.multiply(beta, c)
2877+
x_ = mul1 + mul2
2878+
return tf.identity(x_, name=_TFOUTPUT)
2879+
2880+
def graph_validator(g):
2881+
if 'Gemm' in [n.type for n in g.get_nodes()]: return False
2882+
return True
2883+
2884+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2885+
graph_validator=graph_validator)
2886+
28622887
def test_graph_matcher(self):
28632888
shape = [2, 6]
28642889
x_val = np.random.random(shape).astype(np.float32)

tf2onnx/rewriter/gemm_rewriter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from onnx import onnx_pb
99
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
1010

11+
1112
# pylint: disable=missing-docstring
1213

1314
def rewrite_gemm(g, ops):
@@ -77,17 +78,29 @@ def rewrite_gemm(g, ops):
7778
b_edge_name = matmul_node.input[1]
7879
c_edge_name = input_c_node.output[0]
7980

81+
a_mul_b_shape = g.get_shape(matmul_node.output[0])
82+
c_shape = g.get_shape(c_edge_name)
83+
if c_shape is None: continue
84+
if a_mul_b_shape is None: continue
85+
if -1 in c_shape + a_mul_b_shape: continue
86+
compatible = True
87+
for i in range(1, len(c_shape) + 1):
88+
if c_shape[-i] not in [1, a_mul_b_shape[-i]]:
89+
compatible = False
90+
if not compatible: continue
91+
8092
gemm = g.make_node("Gemm", inputs=[a_edge_name, b_edge_name, c_edge_name],
8193
attr=attr,
8294
shapes=[g.get_shape(add_node.output[0])],
83-
dtypes=[g.get_dtype(add_node.output[0])])
95+
dtypes=[g.get_dtype(add_node.output[0])], op_name_scope=matmul_node.name)
8496

8597
ops.append(gemm)
8698
g.replace_all_inputs(ops, add_node.output[0], gemm.output[0])
8799
to_delete = [add_node, matmul_node]
88100
g.safe_remove_nodes(to_delete)
89101
return ops
90102

103+
91104
def get_gemm_attr(match):
92105
attr = {}
93106
for arg in ["alpha", "beta"]:

0 commit comments

Comments
 (0)