Skip to content

Commit d777dbb

Browse files
authored
fix einsum squeeze (#486)
1 parent ff17c36 commit d777dbb

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

paddlenlp/ops/einsum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,5 +348,6 @@ def _mul_sum(left, right, sum_dims):
348348
squeeze_dims = [
349349
i for i in range(len(result.shape) - 1, num_output_dims - 1, -1)
350350
]
351-
result = paddle.squeeze(result, squeeze_dims)
351+
if len(squeeze_dims) != 0:
352+
result = paddle.squeeze(result, squeeze_dims)
352353
return result

tests/utils/test_ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
import unittest
1616
import paddle
17-
import paddlenlp.utils.ops as ops
17+
import paddlenlp.ops as ops
1818
from common_test import CommonTest
1919

2020
EINSUM_TEST_SAMPLE = {
@@ -29,6 +29,8 @@
2929
"G": np.random.rand(4, 2, 5),
3030
"H": np.random.rand(3, 2, 4),
3131
"I": np.random.rand(2, 2),
32+
"J": np.random.rand(1, 3, 5),
33+
"K": np.random.rand(1, 2, 3, 4),
3234
}
3335

3436

@@ -171,5 +173,10 @@ def setUp(self):
171173
self.sample = {"paradigm": "ijkl, lmn->ijn", "data": ["F", "H"]}
172174

173175

176+
class TestEinsumBatch1(TestEinsum):
177+
def setUp(self):
178+
self.sample = {"paradigm": "blq,bhlk->bhlqk", "data": ["J", "K"]}
179+
180+
174181
if __name__ == "__main__":
175182
unittest.main()

0 commit comments

Comments
 (0)