Skip to content

Commit b6b5d60

Browse files
authored
修复atleast函数中,”输入为tensor的list,输出不是tensor的list“的bug (#73102)
* modified for input of list of tensor in atleast func modified: python/paddle/tensor/manipulation.py modified: test/legacy_test/test_atleast_xd.py * modified tuple and any in condition modified: python/paddle/tensor/manipulation.py * added a test for tuple of tensor modified: test/legacy_test/test_atleast_xd.py * added tests of nested tuple and list modified: test/legacy_test/test_atleast_xd.py
1 parent eb2d94e commit b6b5d60

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5273,6 +5273,20 @@ def atleast_1d(*inputs, name=None):
52735273
[123]), Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
52745274
[[1.23000002]])]
52755275
"""
5276+
if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
5277+
if any(
5278+
isinstance(
5279+
item,
5280+
(
5281+
paddle.Tensor,
5282+
paddle.base.framework.Variable,
5283+
paddle.base.libpaddle.pir.Value,
5284+
),
5285+
)
5286+
for item in inputs[0]
5287+
):
5288+
inputs = inputs[0]
5289+
52765290
out = []
52775291
for input in inputs:
52785292
if not isinstance(
@@ -5362,6 +5376,20 @@ def atleast_2d(*inputs, name=None):
53625376
[[123]]), Tensor(shape=[1, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
53635377
[[[1.23000002]]])]
53645378
"""
5379+
if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
5380+
if any(
5381+
isinstance(
5382+
item,
5383+
(
5384+
paddle.Tensor,
5385+
paddle.base.framework.Variable,
5386+
paddle.base.libpaddle.pir.Value,
5387+
),
5388+
)
5389+
for item in inputs[0]
5390+
):
5391+
inputs = inputs[0]
5392+
53655393
out = []
53665394
for input in inputs:
53675395
if not isinstance(
@@ -5440,6 +5468,20 @@ def atleast_3d(*inputs, name=None):
54405468
[[[123]]]), Tensor(shape=[1, 1, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
54415469
[[[[1.23000002]]]])]
54425470
"""
5471+
if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
5472+
if any(
5473+
isinstance(
5474+
item,
5475+
(
5476+
paddle.Tensor,
5477+
paddle.base.framework.Variable,
5478+
paddle.base.libpaddle.pir.Value,
5479+
),
5480+
)
5481+
for item in inputs[0]
5482+
):
5483+
inputs = inputs[0]
5484+
54435485
out = []
54445486
for input in inputs:
54455487
if not isinstance(

test/legacy_test/test_atleast_xd.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,5 +495,162 @@ def test_as_tensor_method(self):
495495
np.testing.assert_allclose(n, p.numpy(), rtol=RTOL, atol=ATOL)
496496

497497

498+
class TestAtleastWithTensorList(unittest.TestCase):
499+
"""Test when input is a list of paddle tensors"""
500+
501+
def test_tensor_list_input(self):
502+
for device, place in PLACES:
503+
paddle.disable_static(place)
504+
paddle.set_device(device)
505+
506+
tensor_list = [
507+
paddle.to_tensor(123, dtype='int32'), # 0D
508+
paddle.to_tensor([1, 2, 3], dtype='float32'), # 1D
509+
paddle.to_tensor([[1, 2], [3, 4]], dtype='int64'), # 2D
510+
]
511+
512+
# atleast_1d
513+
out_1d = paddle.atleast_1d(tensor_list)
514+
self.assertTrue(isinstance(out_1d, list))
515+
self.assertEqual(len(out_1d), len(tensor_list))
516+
517+
self.assertEqual(out_1d[0].shape, [1])
518+
self.assertEqual(out_1d[1].shape, [3])
519+
self.assertEqual(out_1d[2].shape, [2, 2])
520+
521+
# atleast_2d
522+
out_2d = paddle.atleast_2d(tensor_list)
523+
self.assertTrue(isinstance(out_2d, list))
524+
self.assertEqual(len(out_2d), len(tensor_list))
525+
526+
self.assertEqual(out_2d[0].shape, [1, 1])
527+
self.assertEqual(out_2d[1].shape, [1, 3])
528+
self.assertEqual(out_2d[2].shape, [2, 2])
529+
530+
# atleast_3d
531+
out_3d = paddle.atleast_3d(tensor_list)
532+
self.assertTrue(isinstance(out_3d, list))
533+
self.assertEqual(len(out_3d), len(tensor_list))
534+
535+
self.assertEqual(out_3d[0].shape, [1, 1, 1])
536+
self.assertEqual(out_3d[1].shape, [1, 3, 1])
537+
self.assertEqual(out_3d[2].shape, [2, 2, 1])
538+
539+
np.testing.assert_allclose(
540+
out_1d[0].numpy(), [123], rtol=RTOL, atol=ATOL
541+
)
542+
np.testing.assert_allclose(
543+
out_1d[1].numpy(), [1, 2, 3], rtol=RTOL, atol=ATOL
544+
)
545+
np.testing.assert_allclose(
546+
out_1d[2].numpy(), [[1, 2], [3, 4]], rtol=RTOL, atol=ATOL
547+
)
548+
549+
550+
class TestAtleastWithTensorTuple(unittest.TestCase):
551+
"""Test when input is a tuple of paddle tensors"""
552+
553+
def test_tensor_tuple_input(self):
554+
for device, place in PLACES:
555+
paddle.disable_static(place)
556+
paddle.set_device(device)
557+
558+
tensor_tuple = (
559+
paddle.to_tensor(123, dtype='int32'), # 0D
560+
paddle.to_tensor([1, 2, 3], dtype='float32'), # 1D
561+
paddle.to_tensor([[1, 2], [3, 4]], dtype='int64'), # 2D
562+
)
563+
564+
# atleast_1d
565+
out_1d = paddle.atleast_1d(tensor_tuple)
566+
self.assertTrue(isinstance(out_1d, list))
567+
self.assertEqual(len(out_1d), len(tensor_tuple))
568+
569+
self.assertEqual(out_1d[0].shape, [1])
570+
self.assertEqual(out_1d[1].shape, [3])
571+
self.assertEqual(out_1d[2].shape, [2, 2])
572+
573+
# atleast_2d
574+
out_2d = paddle.atleast_2d(tensor_tuple)
575+
self.assertTrue(isinstance(out_2d, list))
576+
self.assertEqual(len(out_2d), len(tensor_tuple))
577+
578+
self.assertEqual(out_2d[0].shape, [1, 1])
579+
self.assertEqual(out_2d[1].shape, [1, 3])
580+
self.assertEqual(out_2d[2].shape, [2, 2])
581+
582+
# atleast_3d
583+
out_3d = paddle.atleast_3d(tensor_tuple)
584+
self.assertTrue(isinstance(out_3d, list))
585+
self.assertEqual(len(out_3d), len(tensor_tuple))
586+
587+
self.assertEqual(out_3d[0].shape, [1, 1, 1])
588+
self.assertEqual(out_3d[1].shape, [1, 3, 1])
589+
self.assertEqual(out_3d[2].shape, [2, 2, 1])
590+
591+
# Verify values are preserved correctly
592+
np.testing.assert_allclose(
593+
out_1d[0].numpy(), [123], rtol=RTOL, atol=ATOL
594+
)
595+
np.testing.assert_allclose(
596+
out_1d[1].numpy(), [1, 2, 3], rtol=RTOL, atol=ATOL
597+
)
598+
np.testing.assert_allclose(
599+
out_1d[2].numpy(), [[1, 2], [3, 4]], rtol=RTOL, atol=ATOL
600+
)
601+
602+
603+
class TestAtleastWithNestedList(unittest.TestCase):
604+
"""Test when input is a list containing nested lists"""
605+
606+
def test_nested_list_input(self):
607+
for device, place in PLACES:
608+
paddle.disable_static(place)
609+
paddle.set_device(device)
610+
611+
nested_list = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
612+
613+
# atleast_1d
614+
out_1d = paddle.atleast_1d(nested_list)
615+
self.assertEqual(out_1d.shape, [3, 3])
616+
self.assertTrue(isinstance(out_1d, paddle.Tensor))
617+
618+
# atleast_2d
619+
out_2d = paddle.atleast_2d(nested_list)
620+
self.assertEqual(out_2d.shape, [3, 3])
621+
self.assertTrue(isinstance(out_2d, paddle.Tensor))
622+
623+
# atleast_3d
624+
out_3d = paddle.atleast_3d(nested_list)
625+
self.assertEqual(out_3d.shape, [3, 3, 1])
626+
self.assertTrue(isinstance(out_3d, paddle.Tensor))
627+
628+
629+
class TestAtleastWithNestedTuple(unittest.TestCase):
630+
"""Test when input is a tuple containing nested tuples"""
631+
632+
def test_nested_tuple_input(self):
633+
for device, place in PLACES:
634+
paddle.disable_static(place)
635+
paddle.set_device(device)
636+
637+
nested_tuple = ((1, 2, 3), (1, 2, 3), (1, 2, 3))
638+
639+
# atleast_1d
640+
out_1d = paddle.atleast_1d(nested_tuple)
641+
self.assertEqual(out_1d.shape, [3, 3])
642+
self.assertTrue(isinstance(out_1d, paddle.Tensor))
643+
644+
# atleast_2d
645+
out_2d = paddle.atleast_2d(nested_tuple)
646+
self.assertEqual(out_2d.shape, [3, 3])
647+
self.assertTrue(isinstance(out_2d, paddle.Tensor))
648+
649+
# atleast_3d
650+
out_3d = paddle.atleast_3d(nested_tuple)
651+
self.assertEqual(out_3d.shape, [3, 3, 1])
652+
self.assertTrue(isinstance(out_3d, paddle.Tensor))
653+
654+
498655
if __name__ == '__main__':
499656
unittest.main()

0 commit comments

Comments
 (0)