Skip to content

Commit e04e29d

Browse files
committed
Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into bench
2 parents 1b381dd + d4be5af commit e04e29d

File tree

4 files changed

+686
-57
lines changed

4 files changed

+686
-57
lines changed

tests/test_einsum_helper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ def test_np_test_edge_cases_duplicate_indices(self):
328328
self.optimize_compare('baa,dcf,af,cde->be')
329329
self.optimize_compare('fff,fae,bef,def->abd')
330330

331+
def test_abbba(self):
332+
decompose_einsum_equation("ab,b->ba")
333+
331334

332335
if __name__ == "__main__":
333336
unittest.main()

tests/test_einsum_ml.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""Unit Tests for einsum decomposition."""
5+
6+
import unittest
7+
from itertools import permutations
8+
import numpy as np
9+
from numpy.testing import assert_almost_equal
10+
from onnx import helper, TensorProto, numpy_helper
11+
from tf2onnx.optimizer.einsum_optimizer import (
12+
OnnxMicroRuntime,
13+
predict_transposition_cost,
14+
compute_transposition_features)
15+
from tf2onnx import constants
16+
from backend_test_base import Tf2OnnxBackendTestBase
17+
18+
19+
class TestEinsumMl(Tf2OnnxBackendTestBase):
20+
"unit tests for einsum optimizer"
21+
22+
def test_onnx_micro_runtime(self):
23+
"test OnnxMicroRuntime"
24+
opset = self.config.opset
25+
x = np.array([1, 2, 4, 5, 5, 4]).astype(
26+
np.float32).reshape((3, 2))
27+
28+
model_def = helper.make_model(
29+
opset_imports=[helper.make_operatorsetid('', opset)],
30+
ir_version=constants.OPSET_TO_IR_VERSION[opset],
31+
producer_name='tf2onnx',
32+
producer_version='0.0.1',
33+
graph=helper.make_graph(
34+
name='einsum',
35+
inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, None)],
36+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
37+
nodes=[
38+
helper.make_node('Add', ["X", "X"], ["temp"]),
39+
helper.make_node('Add', ["X", "temp"], ["Y"]),
40+
]))
41+
42+
rt = OnnxMicroRuntime(model_def)
43+
out = rt.run({'X': x})
44+
self.assertIn('X', out)
45+
self.assertIn('Y', out)
46+
self.assertIn('temp', out)
47+
self.assertEqual(len(out), 3)
48+
49+
def test_onnx_micro_runtime_exc1(self):
50+
"test OnnxMicroRuntime"
51+
with self.assertRaises(TypeError):
52+
OnnxMicroRuntime(None)
53+
54+
def test_onnx_micro_runtime_exc2(self):
55+
"test OnnxMicroRuntime"
56+
opset = self.config.opset
57+
x = np.array([1, 2, 4, 5, 5, 4]).astype(
58+
np.float32).reshape((3, 2))
59+
60+
model_def = helper.make_model(
61+
opset_imports=[helper.make_operatorsetid('', opset)],
62+
ir_version=constants.OPSET_TO_IR_VERSION[opset],
63+
producer_name='tf2onnx',
64+
producer_version='0.0.1',
65+
graph=helper.make_graph(
66+
name='einsum',
67+
inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, None)],
68+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
69+
initializer=[
70+
numpy_helper.from_array(np.array([1], dtype=np.float32), name="C1"),
71+
numpy_helper.from_array(np.array([2], dtype=np.float32), name="C2"),
72+
],
73+
nodes=[
74+
helper.make_node('Add', ["X", "C1"], ["temp"]),
75+
helper.make_node('Pow', ["temp", "C2"], ["Y"]),
76+
]))
77+
78+
rt = OnnxMicroRuntime(model_def)
79+
with self.assertRaises(NotImplementedError):
80+
rt.run({'X': x})
81+
with self.assertRaises(TypeError):
82+
rt.run(x)
83+
84+
def test_onnx_micro_runtime_shape(self):
85+
"test OnnxMicroRuntime"
86+
opset = self.config.opset
87+
x = np.array([1, 2, 4, 5, 5, 4]).astype(
88+
np.float32).reshape((3, 2))
89+
90+
model_def = helper.make_model(
91+
opset_imports=[helper.make_operatorsetid('', opset)],
92+
ir_version=constants.OPSET_TO_IR_VERSION[opset],
93+
producer_name='tf2onnx',
94+
producer_version='0.0.1',
95+
graph=helper.make_graph(
96+
name='einsum',
97+
inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, None)],
98+
outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, None)],
99+
nodes=[
100+
helper.make_node('Shape', ["X"], ["Y"]),
101+
]))
102+
103+
rt = OnnxMicroRuntime(model_def)
104+
out = rt.run({'X': x})
105+
assert_almost_equal(np.array(x.shape, dtype=np.int64), out['Y'])
106+
107+
def test_onnx_micro_runtime_unsqueeze(self):
108+
"test OnnxMicroRuntime"
109+
opset = self.config.opset
110+
x = np.array([1, 2, 4, 5, 5, 4]).astype(
111+
np.float32).reshape((3, 2))
112+
i = np.array([1]).astype(np.int64)
113+
114+
model_def = helper.make_model(
115+
opset_imports=[helper.make_operatorsetid('', opset)],
116+
ir_version=constants.OPSET_TO_IR_VERSION[opset],
117+
producer_name='tf2onnx',
118+
producer_version='0.0.1',
119+
graph=helper.make_graph(
120+
name='einsum',
121+
inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, None),
122+
helper.make_tensor_value_info('I', TensorProto.INT64, None)],
123+
outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, None)],
124+
nodes=[
125+
helper.make_node('Unsqueeze', ["X", "I"], ["Y"]),
126+
]))
127+
128+
rt = OnnxMicroRuntime(model_def)
129+
out = rt.run({'X': x, 'I': i})
130+
assert_almost_equal(np.array(x.reshape((3, 1, 2))), out['Y'])
131+
132+
def test_onnx_micro_runtime_transpose(self):
133+
"test OnnxMicroRuntime"
134+
opset = self.config.opset
135+
x = np.array([1, 2, 4, 5, 5, 4]).astype(
136+
np.float32).reshape((3, 2))
137+
138+
model_def = helper.make_model(
139+
opset_imports=[helper.make_operatorsetid('', opset)],
140+
ir_version=constants.OPSET_TO_IR_VERSION[opset],
141+
producer_name='tf2onnx',
142+
producer_version='0.0.1',
143+
graph=helper.make_graph(
144+
name='einsum',
145+
inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, None)],
146+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
147+
nodes=[
148+
helper.make_node('Transpose', ["X"], ["Y"], perm=[1, 0]),
149+
]))
150+
151+
rt = OnnxMicroRuntime(model_def)
152+
out = rt.run({'X': x})
153+
assert_almost_equal(x.T, out['Y'])
154+
155+
def test_onnx_micro_runtime_matmul(self):
156+
"test OnnxMicroRuntime"
157+
opset = self.config.opset
158+
x = np.array([1, 2, 4, 5]).astype(
159+
np.float32).reshape((2, 2))
160+
161+
model_def = helper.make_model(
162+
opset_imports=[helper.make_operatorsetid('', opset)],
163+
ir_version=constants.OPSET_TO_IR_VERSION[opset],
164+
producer_name='tf2onnx',
165+
producer_version='0.0.1',
166+
graph=helper.make_graph(
167+
name='einsum',
168+
inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, None)],
169+
outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, None)],
170+
nodes=[
171+
helper.make_node('MatMul', ["X", "X"], ["Y"]),
172+
]))
173+
174+
rt = OnnxMicroRuntime(model_def)
175+
out = rt.run({'X': x})
176+
assert_almost_equal(np.matmul(x, x), out['Y'])
177+
178+
def test_features(self):
179+
res = compute_transposition_features((3, 5, 7), (0, 1, 2))
180+
self.assertIsInstance(res, dict)
181+
self.assertEqual(res["edit"], 0)
182+
self.assertEqual(res["rot"], -1)
183+
res = compute_transposition_features((3, 5, 7), (2, 1, 0))
184+
self.assertEqual(res["edit"], 2)
185+
self.assertEqual(res["rot"], 0)
186+
self.assertEqual(res["rev"], 1)
187+
188+
def test_cost(self):
189+
res = predict_transposition_cost((300, 500, 700), (0, 1, 2))
190+
self.assertIsInstance(res, float)
191+
self.assertGreater(res, 0)
192+
for shape in [(3, 5, 7), (30, 50, 70)]:
193+
for perm in permutations([0, 1, 2]):
194+
p = tuple(perm)
195+
cost = predict_transposition_cost(shape, p)
196+
if p[-1] == 2:
197+
self.assertEqual(cost, 0)
198+
199+
200+
if __name__ == "__main__":
201+
unittest.main()

tests/test_einsum_optimizers.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,23 @@ def common_einsum(self, equation, operands=None, catch_errors=True):
106106
def test_np_test_broadcasting_dot_cases2(self):
107107
f = np.arange(7 * 55).reshape(7, 11, 5).astype(np.float32)
108108
g = np.arange(30).reshape(2, 3, 5).astype(np.float32)
109-
self.common_einsum('obk,ijk->ioj', operands=[f, g],
110-
catch_errors=False)
109+
self.common_einsum('obk,ijk->ioj', operands=[f, g], catch_errors=True)
110+
111+
@check_opset_min_version(13, "Unsqueeze")
112+
def test_np_test_broadcasting_double_transpose(self):
113+
f = np.arange(10).reshape(2, 5).astype(np.float32)
114+
g = np.arange(5).astype(np.float32)
115+
self.common_einsum('ab,b->ab', operands=[f, g], catch_errors=True)
116+
self.common_einsum('ab,b->ba', operands=[f, g], catch_errors=True)
117+
118+
@check_opset_min_version(13, "Unsqueeze")
119+
def test_np_test_broadcasting_dot_cases3(self):
120+
f = np.arange(12).reshape(2, 3, 2).astype(np.float32)
121+
g = np.arange(6).reshape(2, 3).astype(np.float32)
122+
self.common_einsum('abi,ab->ab', operands=[f, g], catch_errors=True)
123+
f = np.arange(12).reshape(2, 3, 2).astype(np.float32)
124+
g = np.arange(12).reshape(2, 3, 2).astype(np.float32)
125+
self.common_einsum('abi,abq->abq', operands=[f, g], catch_errors=True)
111126

112127
@check_opset_min_version(12, "Einsum")
113128
def test_np_test_exp(self):

0 commit comments

Comments
 (0)