Skip to content

Commit ba1b0ad

Browse files
committed
add Einsum
1 parent 9ace355 commit ba1b0ad

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

tests/test_backend.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3138,20 +3138,22 @@ def func(x):
31383138
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
31393139

31403140
@check_opset_min_version(12)
3141-
def test_less_or_equal(self):
3141+
def test_squared_distance(self):
31423142
x_val = np.random.random([4, 5]).astype(np.float32)
31433143
y_val = np.random.random([4, 5]).astype(np.float32)
31443144
def func(x, y):
3145-
return tf.math.less_equal(x, y, name=_TFOUTPUT)
3145+
return tf.math.squared_difference(x, y, name=_TFOUTPUT)
31463146
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
31473147

31483148
@check_opset_min_version(12)
3149-
def test_squared_distance(self):
3150-
x_val = np.random.random([4, 5]).astype(np.float32)
3151-
y_val = np.random.random([4, 5]).astype(np.float32)
3149+
@check_tf_min_version("2.1")
3150+
def test_einsum(self):
3151+
x_val = np.random.random([10]).astype(np.float32)
3152+
y_val = np.random.random([10]).astype(np.float32)
31523153
def func(x, y):
3153-
return tf.math.squared_difference(x, y, name=_TFOUTPUT)
3154-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3154+
ret = tf.einsum("i,j->ij", x, y)
3155+
return tf.identity(ret, name=_TFOUTPUT)
3156+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
31553157

31563158

31573159
if __name__ == '__main__':

tf2onnx/onnx_opset/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,12 @@ class SquaredDistance:
561561
def version_12(cls, ctx, node, **kwargs):
562562
node.attr["reduction"] = "none"
563563

564+
@tf_op("Einsum")
565+
class Einsum:
566+
@classmethod
567+
def version_12(cls, ctx, node, **kwargs):
568+
del node.attr["N"]
569+
570+
571+
572+

0 commit comments

Comments
 (0)