Skip to content

Commit 4e1315a

Browse files
Switch Einsum to Matmul when 2nd inp is constant (#1457)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent da3d0b4 commit 4e1315a

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4661,6 +4661,26 @@ def func(x, y):
46614661
return tf.identity(ret, name=_TFOUTPUT)
46624662
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
46634663

4664+
@check_opset_min_version(12)
4665+
@check_tf_min_version("2.1")
4666+
def test_einsum_to_matmul(self):
4667+
x_val = np.random.random([4, 10, 20]).astype(np.float32)
4668+
y_val = np.random.random([20, 30]).astype(np.float32)
4669+
def func(x, y):
4670+
ret = tf.einsum("bik,kj->bij", x, y)
4671+
return tf.identity(ret, name=_TFOUTPUT)
4672+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
4673+
4674+
@check_opset_min_version(12)
4675+
@check_tf_min_version("2.1")
4676+
def test_einsum_to_matmul_transpose(self):
4677+
x_val = np.random.random([4, 10, 20]).astype(np.float32)
4678+
y_val = np.random.random([30, 20]).astype(np.float32)
4679+
def func(x, y):
4680+
ret = tf.einsum("bik,jk->bij", x, y)
4681+
return tf.identity(ret, name=_TFOUTPUT)
4682+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
4683+
46644684
@check_opset_min_version(7)
46654685
def test_compare(self):
46664686
x_val = np.random.random([10, 20]).astype(np.float32)

tf2onnx/onnx_opset/math.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,43 @@ class Einsum:
580580
def version_12(cls, ctx, node, **kwargs):
581581
del node.attr["N"]
582582
node.attr["equation"].s = node.attr["equation"].s.lower()
583+
def should_replace_with_matmul():
584+
# True is 2nd inp is const and eqn is ...ik,kj->...ij (possibly transpose 2nd inp)
585+
# When the 2nd input is const, ort pre-packs the Matmul but not Einsum so this is faster
586+
eqn = node.get_attr_value("equation").decode()
587+
parts = eqn.split('->')
588+
lhs = parts[0]
589+
terms = lhs.split(',')
590+
if len(parts) >= 2:
591+
rhs = parts[1]
592+
else:
593+
rhs = sorted(terms)
594+
if len(terms) != 2:
595+
return False, None
596+
t1, t2 = terms
597+
# No repeat vars and all terms have >= 2 vars
598+
if any(len(set(t)) < len(t) or len(t) < 2 for t in [t1, t2, rhs]):
599+
return False, None
600+
if len(t2) != 2:
601+
return False, None
602+
i = rhs[-2]
603+
j = rhs[-1]
604+
if t2[0] == j:
605+
k = t2[1]
606+
transpose_t2 = True
607+
elif t2[1] == j:
608+
k = t2[0]
609+
transpose_t2 = False
610+
else:
611+
return False, None
612+
return t1.endswith(i + k) and t1[:-2] == rhs[:-2], transpose_t2
613+
should_replace, transpose_t2 = should_replace_with_matmul()
614+
if should_replace:
615+
if transpose_t2:
616+
inp_trans = ctx.make_node("Transpose", [node.input[1]], attr={'perm': [1, 0]}).output[0]
617+
ctx.replace_inputs(node, [node.input[0], inp_trans])
618+
node.type = "MatMul"
619+
del node.attr["equation"]
583620

584621

585622
@tf_op("IsFinite")

0 commit comments

Comments
 (0)