Skip to content

Commit 963a51e

Browse files
xadupresdpythonTomWildenhain-Microsoft
authored
Add functions to decompose einsum summation (#1472)
* Add function to decompose einsum summation Signed-off-by: xavier dupré <[email protected]> * header Signed-off-by: xavier dupré <[email protected]> * to_onnx Signed-off-by: xavier dupré <[email protected]> * fix doot Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * fix opset Signed-off-by: xavier dupré <[email protected]> * remove useless transpose Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * use gemm if possible Signed-off-by: xavier dupré <[email protected]> * add main files to handle conversion Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * complete optimizer Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * add optimization Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * einsum Signed-off-by: xavier dupré <[email protected]> * ut Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * enable einsum optimizer Signed-off-by: xavier dupré <[email protected]> * fix missing output Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> * small design change Signed-off-by: xavier dupré <[email protected]> * lint Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]> Co-authored-by: TomWildenhain-Microsoft <[email protected]>
1 parent 473a9ce commit 963a51e

File tree

5 files changed

+2373
-7
lines changed

5 files changed

+2373
-7
lines changed

tests/test_einsum_helper.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""Unit Tests for einsum decomposition."""
5+
6+
import unittest
7+
import itertools
8+
import numpy as np
9+
from numpy.testing import assert_almost_equal
10+
from onnxruntime import InferenceSession
11+
from tf2onnx.optimizer.einsum_optimizer import (
12+
analyse_einsum_equation, decompose_einsum_equation, EinsumSubOp)
13+
from backend_test_base import Tf2OnnxBackendTestBase
14+
from common import check_opset_min_version
15+
16+
17+
class TestEinsum(Tf2OnnxBackendTestBase):
    """Unit tests for the einsum optimizer: equation analysis and decomposition."""

    def assert_raise(self, fct, exc_type):
        """Fail unless calling ``fct()`` raises ``exc_type``.

        Exceptions of any other type propagate unchanged to the caller.
        """
        try:
            fct()
        except exc_type:
            return
        raise AssertionError("%r was not raised." % exc_type)

    def apply_einsum_sequence(self, seq, *inputs):
        """Convert ``seq`` to ONNX, run it with onnxruntime, return the first output.

        The graph inputs are named ``X0``, ``X1``, ... and the output ``Y``.
        Every input array is cast to float32 before being fed to the session.
        """
        names = ["X%d" % i for i in range(len(inputs))]
        onx = seq.to_onnx('Y', *names, opset=self.config.opset)
        sess = InferenceSession(onx.SerializeToString())
        feeds = {name: value.astype(np.float32)
                 for name, value in zip(names, inputs)}
        return sess.run(None, feeds)[0]

    def test_analyse_einsum_equation(self):
        """Check rejected equations and the analysis of ``abc,ch->ah``."""
        self.assert_raise(lambda: analyse_einsum_equation("abc"), NotImplementedError)
        self.assert_raise(lambda: analyse_einsum_equation("abc0,ch->ah"), ValueError)
        self.assert_raise(lambda: analyse_einsum_equation("abc,ch->a0"), ValueError)
        res = analyse_einsum_equation("abc,ch->ah")
        self.assertEqual(len(res), 4)
        letters, mat, lengths, duplicates = res
        self.assertEqual(letters, "abch")
        assert_almost_equal(lengths, np.array([3, 2, 2]))
        assert_almost_equal(mat, np.array([[0, 1, 2, -1], [-1, -1, 0, 1], [0, -1, -1, 1]]))
        self.assertEqual(duplicates, [None, None, None])

    def test_analyse_einsum_equation_duplicates(self):
        """Check the analysis of an equation with repeated letters in a term."""
        res = analyse_einsum_equation("aac,ca->aa")
        self.assertEqual(len(res), 4)
        letters, mat, lengths, duplicates = res
        self.assertEqual(letters, "ac")
        assert_almost_equal(lengths, np.array([3, 2, 2]))
        self.assertEqual(duplicates, [{'a': [0, 1], 'c': [2]}, None, {'a': [0, 1]}])
        assert_almost_equal(mat, np.array([[1, 2], [1, 0], [1, -1]]))

    @check_opset_min_version(13, "Squeeze")
    def test_decompose_einsum_equation(self):
        """Decompose ``bac,ch->ah``, inspect the dot graph and compare with numpy."""
        m1 = np.arange(0, 8).astype(np.float32).reshape((2, 2, 2))
        m2 = np.arange(0, 4).astype(np.float32).reshape((2, 2))
        exp = np.einsum("bac,ch->ah", m1, m2)
        seq = decompose_einsum_equation("bac,ch->ah", (2, 2, 2), (2, 2))
        dot = seq.to_dot()
        # Five pieces after splitting means 'red' occurs exactly four times
        # in the dot text.
        red = dot.split('red')
        self.assertEqual(len(red), 5)
        res = self.apply_einsum_sequence(seq, m1, m2)
        assert_almost_equal(exp, res)

    @check_opset_min_version(13, "Squeeze")
    def test_decompose_einsum_equation_deep_case(self):
        """Decompose a two-input rank-4 equation without giving shapes."""
        m1 = np.arange(0, 16).astype(np.float32).reshape((2, 2, 2, 2))
        m2 = np.arange(0, 16).astype(np.float32).reshape((2, 2, 2, 2))
        exp = np.einsum("bsnh,btnh->bnts", m1, m2)
        seq = decompose_einsum_equation("bsnh,btnh->bnts")
        res = self.apply_einsum_sequence(seq, m1, m2)
        assert_almost_equal(exp, res)

    @check_opset_min_version(13, "Squeeze")
    def test_decompose_einsum_equation_onnx(self):
        """Decompose ``bac,ch->ah`` with explicit, non-square shapes."""
        m1 = np.arange(0, 24).astype(np.float32).reshape((2, 3, 4))
        m2 = np.arange(0, 20).astype(np.float32).reshape((4, 5))
        seq = decompose_einsum_equation("bac,ch->ah", (2, 3, 4), (4, 5))
        exp = np.einsum("bac,ch->ah", m1, m2)
        res = self.apply_einsum_sequence(seq, m1, m2)
        assert_almost_equal(exp, res)

    @check_opset_min_version(13, "Squeeze")
    def test_decompose_einsum_equation_noshape(self):
        """Decompose ``bac,ch->ah`` without passing any input shape."""
        m1 = np.arange(0, 24).astype(np.float32).reshape((2, 3, 4))
        m2 = np.arange(0, 20).astype(np.float32).reshape((4, 5))
        seq = decompose_einsum_equation("bac,ch->ah")
        exp = np.einsum("bac,ch->ah", m1, m2)
        res = self.apply_einsum_sequence(seq, m1, m2)
        assert_almost_equal(exp, res)

    @check_opset_min_version(13, "Squeeze")
    def test_decompose_einsum_equation_onnx2(self):
        """Decompose the three-input equation ``bac,cd,def->ebc``."""
        m1 = np.arange(0, 24).astype(np.float32).reshape((2, 3, 4))
        m2 = np.arange(0, 20).astype(np.float32).reshape((4, 5))
        m3 = np.arange(0, 77 * 5).astype(np.float32).reshape((5, 7, 11))

        seq = decompose_einsum_equation(
            "bac,cd,def->ebc", (2, 3, 4), (4, 5), (5, 7, 11))
        exp = np.einsum("bac,cd,def->ebc", m1, m2, m3)
        res = self.apply_einsum_sequence(seq, m1, m2, m3)
        assert_almost_equal(exp, res)

    def test_einsum_sub_op(self):
        """Check that invalid ``EinsumSubOp`` constructions are rejected."""
        self.assert_raise(lambda: EinsumSubOp(2, "er", (2, 2)), ValueError)
        self.assert_raise(lambda: EinsumSubOp(2, "expand_dims"), RuntimeError)
        self.assert_raise(lambda: EinsumSubOp(2, "matmul", (2, 2)), RuntimeError)
        self.assert_raise(lambda: EinsumSubOp(2, "id", (2, 2)), TypeError)

    def common_test_case_2(self, equation):
        """Compare numpy and ONNX on ``equation`` with fixed (2, 2, 2) and (2, 2) inputs."""
        m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = np.arange(4).reshape((2, 2)) + 100
        exp = np.einsum(equation, m1, m2)

        seq = decompose_einsum_equation(equation, m1.shape, m2.shape)
        res = self.apply_einsum_sequence(seq, m1, m2)
        assert_almost_equal(exp, res)

    @check_opset_min_version(13, "Squeeze")
    def test_case_2_a(self):
        """Run the shared two-input check on ``abc,cd->abc``."""
        self.common_test_case_2('abc,cd->abc')

    @check_opset_min_version(13, "Squeeze")
    def test_many_2(self):
        """Check many generated two-input equations against numpy."""
        m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = np.arange(4).reshape((2, 2)) + 100

        # product() iterates its last argument fastest, which is the same
        # order as the equivalent nested loops p1 > p2 > i > j.
        cases = []
        for p1, p2, i, j in itertools.product(
                itertools.permutations("abc"),
                itertools.permutations("cd"), [1, 2], [0, 1]):
            sp1 = "".join(p1)
            sp2 = "".join(p2)
            # Keep only equations whose three output letters are distinct.
            if len({sp1[0], sp1[i], sp2[j]}) != 3:
                continue
            equation = "%s,%s->%s%s%s" % (
                sp1, sp2, sp1[0], sp1[i], sp2[j])
            try:
                expected = np.einsum(equation, m1, m2)
            except ValueError:
                # Not viable equation.
                continue
            cases.append((equation, expected))

        # The total is computed once and the ONNX result gets its own name, so
        # the list being iterated is never rebound inside the loop.
        total = len(cases)
        for i, (eq, exp) in enumerate(cases):
            with self.subTest(equation=eq, index=i, total=total):
                seq = decompose_einsum_equation(
                    eq, m1.shape, m2.shape)
                got = self.apply_einsum_sequence(seq, m1, m2)
                assert_almost_equal(exp, got)

    @check_opset_min_version(13, "Squeeze")
    def test_many_3(self):
        """Check many generated three-input equations against numpy."""
        m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = np.arange(4).reshape((2, 2)) + 100
        m3 = np.arange(8).reshape((2, 2, 2)) + 1000

        # Same iteration order as the nested loops p1 > p2 > p3 > i > j.
        cases = []
        for p1, p2, p3, i, j in itertools.product(
                itertools.permutations("abc"),
                itertools.permutations("cd"),
                itertools.permutations("def"), [1, 2], [0, 1]):
            sp1 = "".join(p1)
            sp2 = "".join(p2)
            sp3 = "".join(p3)
            equation = "%s,%s,%s->%s%s%s" % (
                sp1, sp2, sp3, sp1[0], sp1[i], sp3[j])
            try:
                expected = np.einsum(equation, m1, m2, m3)
            except ValueError:
                # Not viable equation.
                continue
            cases.append((equation, expected))

        total = len(cases)
        for i, (eq, exp) in enumerate(cases):
            with self.subTest(equation=eq, index=i, total=total):
                seq = decompose_einsum_equation(
                    eq, m1.shape, m2.shape, m3.shape)
                got = self.apply_einsum_sequence(seq, m1, m2, m3)
                assert_almost_equal(exp, got)

    # The cases below are taken from numpy's own einsum tests:
    # https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_einsum.py

    def optimize_compare(self, equation, operands=None):
        """Compare numpy einsum with the ONNX decomposition of ``equation``.

        When ``operands`` is None, one float32 input of shape ``(2,) * rank``
        is generated for each term on the left of ``->`` and shifted by
        ``3 ** position`` so that the inputs differ from one another.
        """
        with self.subTest(equation=equation):
            if operands is not None:
                inputs = operands
            else:
                terms = equation.split("->")[0].split(",")
                inputs = []
                for d, term in enumerate(terms):
                    base = np.arange(2 ** len(term)).reshape(
                        (2,) * len(term)).astype(np.float32)
                    inputs.append(
                        base + np.array([3 ** d], dtype=np.float32))

            exp = np.einsum(equation, *inputs)
            shapes = [m.shape for m in inputs]

            seq = decompose_einsum_equation(equation, *shapes)
            got = self.apply_einsum_sequence(seq, *inputs)
            assert_almost_equal(exp, got, decimal=5)

    @check_opset_min_version(13, "Squeeze")
    def test_numpy_test_hadamard_like_products(self):
        """numpy cases: Hadamard-like products."""
        self.optimize_compare('a,ab,abc->abc')
        self.optimize_compare('a,b,ab->ab')

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_np_test_collapse(self):
        """numpy cases: collapse."""
        self.optimize_compare('ab,ab,cd,cd->ac')
        self.optimize_compare('ab,ab,c->c')
        self.optimize_compare('ab,ab,cd,cd->cd')

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_index_transformations(self):
        """numpy cases: index transformations."""
        self.optimize_compare('ea,fb,gc,hd,abcd->efgh')
        self.optimize_compare('ea,fb,abcd,gc,hd->efgh')
        self.optimize_compare('abcd,ea,fb,gc,hd->efgh')

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_expand(self):
        """numpy cases: expand."""
        self.optimize_compare('ab,cd,ef->abcdef')
        self.optimize_compare('ab,cd,ef->acdf')
        self.optimize_compare('ab,cd,de->abcde')
        self.optimize_compare('ab,cd,de->be')
        self.optimize_compare('ab,bcd,cd->abcd')
        self.optimize_compare('ab,bcd,cd->abd')

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_edge_cases1(self):
        """numpy cases: edge cases (first group)."""
        self.optimize_compare('efc,dbc,acf,fd->abe')
        self.optimize_compare(
            'eac->ace', operands=[np.arange(24).reshape((2, 3, 4))])
        self.optimize_compare('eac->ace')
        self.optimize_compare('bd,db,eac->ace')
        self.optimize_compare('ba,ac,da->bcd')

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_edge_cases2(self):
        """numpy cases: edge cases (second group)."""
        self.optimize_compare(
            'eac->ace', operands=[np.arange(24).reshape((2, 3, 4))])
        self.optimize_compare('eb,cb,fb->cef')

    @unittest.skip("diagonal still not converted into ONNX")
    def test_np_test_random_cases(self):
        """numpy cases: random equations (skipped)."""
        self.optimize_compare('aab,fa,df,ecc->bde')
        self.optimize_compare('bb,ff,be->e')
        self.optimize_compare('afd,ba,cc,dc->bf')
        self.optimize_compare('bbd,bda,fc,db->acf')
        self.optimize_compare('dba,ead,cad->bce')
        self.optimize_compare('aef,fbc,dca->bde')

    def test_np_test_combined_views_mapping(self):
        """Check numpy's own einsum on repeated indices; no ONNX is involved."""
        a = np.arange(9).reshape(1, 1, 3, 1, 3)
        b = np.einsum('bbcdc->d', a)
        assert_almost_equal(b, [12])

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_broadcasting_dot_cases1(self):
        """numpy cases: broadcasting dot products on random operands."""
        a = np.random.rand(1, 5, 4)
        b = np.random.rand(4, 6)
        c = np.random.rand(5, 6)
        d = np.random.rand(10)
        self.optimize_compare('ijk,kl,jl,i->i', operands=[a, b, c, d])
        e = np.random.rand(1, 1, 5, 4)
        f = np.random.rand(7, 7)
        self.optimize_compare('abjk,kl,jl,ab->ab', operands=[e, b, c, f])

    @check_opset_min_version(13, "Squeeze")
    def test_np_test_broadcasting_dot_cases2(self):
        """numpy cases: broadcasting dot products on fixed operands."""
        f = np.arange(7 * 55).reshape(7, 11, 5)
        g = np.arange(30).reshape(2, 3, 5)
        self.optimize_compare('obk,ijk->ioj', operands=[f, g])

    def np_test_complex(self):
        """numpy cases: long equations without an explicit output.

        NOTE(review): the name has no ``test_`` prefix, so unittest does not
        collect this method. The equations contain no ``->``, a form that
        ``analyse_einsum_equation`` rejects with NotImplementedError in
        ``test_analyse_einsum_equation``; presumably disabled for that reason,
        to be confirmed.
        """
        self.optimize_compare('acdf,jbje,gihb,hfac,gfac,gifabc,hfac')
        self.optimize_compare('acdf,jbje,gihb,hfac,gfac,gifabc,hfac')
        self.optimize_compare('cd,bdhe,aidb,hgca,gc,hgibcd,hgac')
        self.optimize_compare('abhe,hidj,jgba,hiab,gab')
        self.optimize_compare('bde,cdh,agdb,hica,ibd,hgicd,hiac')
        self.optimize_compare('chd,bde,agbc,hiad,hgc,hgi,hiad')
        self.optimize_compare('chd,bde,agbc,hiad,bdi,cgh,agdb')
        self.optimize_compare('bdhe,acad,hiab,agac,hibd')

    def np_test_inner_product(self):
        """numpy cases: inner products without an explicit output.

        NOTE(review): the name has no ``test_`` prefix, so unittest does not
        collect this method. The equations contain no ``->``; presumably
        disabled because that form is not supported, to be confirmed.
        """
        self.optimize_compare('ab,ab')
        self.optimize_compare('ab,ba')
        self.optimize_compare('abc,abc')
        self.optimize_compare('abc,bac')
        self.optimize_compare('abc,cba')

    @unittest.skip("diagonal still not converted into ONNX")
    def test_np_test_random_cases_difficult(self):
        """numpy cases: difficult random equations (skipped)."""
        self.optimize_compare('db,bc,cfc->d')
        self.optimize_compare('cac,c,h->h')
        self.optimize_compare('cfc,c,h->h')
        self.optimize_compare('cfc,c,d->d')
        self.optimize_compare('c,cfc,d->d')
        self.optimize_compare('d,c,cfc->d')
        self.optimize_compare('d,bc,cfc->d')
        self.optimize_compare('adb,bc,cfc->d')
        self.optimize_compare('adb,bc,fa,cfc->d')
        self.optimize_compare('ecb,fef,bad,ed->ac')
        self.optimize_compare('fdf,cdd,ccd,afe->ae')
        self.optimize_compare('adb,cfc->d')

    @unittest.skip("diagonal still not converted into ONNX")
    def test_np_test_edge_cases_duplicate_indices(self):
        """numpy cases: edge cases with duplicated indices (skipped)."""
        self.optimize_compare('dd,fb,be,cdb->cef')
        self.optimize_compare('dcc,fce,ea,dbf->ab')
        self.optimize_compare('ed,fcd,ff,bcf->be')
        self.optimize_compare('baa,dcf,af,cde->be')
        self.optimize_compare('fff,fae,bef,def->abd')
330+
331+
332+
if __name__ == "__main__":
    # Run every test in this module when the file is executed as a script.
    unittest.main()

0 commit comments

Comments
 (0)