@@ -495,5 +495,162 @@ def test_as_tensor_method(self):
495
495
np .testing .assert_allclose (n , p .numpy (), rtol = RTOL , atol = ATOL )
496
496
497
497
498
+ class TestAtleastWithTensorList (unittest .TestCase ):
499
+ """Test when input is a list of paddle tensors"""
500
+
501
+ def test_tensor_list_input (self ):
502
+ for device , place in PLACES :
503
+ paddle .disable_static (place )
504
+ paddle .set_device (device )
505
+
506
+ tensor_list = [
507
+ paddle .to_tensor (123 , dtype = 'int32' ), # 0D
508
+ paddle .to_tensor ([1 , 2 , 3 ], dtype = 'float32' ), # 1D
509
+ paddle .to_tensor ([[1 , 2 ], [3 , 4 ]], dtype = 'int64' ), # 2D
510
+ ]
511
+
512
+ # atleast_1d
513
+ out_1d = paddle .atleast_1d (tensor_list )
514
+ self .assertTrue (isinstance (out_1d , list ))
515
+ self .assertEqual (len (out_1d ), len (tensor_list ))
516
+
517
+ self .assertEqual (out_1d [0 ].shape , [1 ])
518
+ self .assertEqual (out_1d [1 ].shape , [3 ])
519
+ self .assertEqual (out_1d [2 ].shape , [2 , 2 ])
520
+
521
+ # atleast_2d
522
+ out_2d = paddle .atleast_2d (tensor_list )
523
+ self .assertTrue (isinstance (out_2d , list ))
524
+ self .assertEqual (len (out_2d ), len (tensor_list ))
525
+
526
+ self .assertEqual (out_2d [0 ].shape , [1 , 1 ])
527
+ self .assertEqual (out_2d [1 ].shape , [1 , 3 ])
528
+ self .assertEqual (out_2d [2 ].shape , [2 , 2 ])
529
+
530
+ # atleast_3d
531
+ out_3d = paddle .atleast_3d (tensor_list )
532
+ self .assertTrue (isinstance (out_3d , list ))
533
+ self .assertEqual (len (out_3d ), len (tensor_list ))
534
+
535
+ self .assertEqual (out_3d [0 ].shape , [1 , 1 , 1 ])
536
+ self .assertEqual (out_3d [1 ].shape , [1 , 3 , 1 ])
537
+ self .assertEqual (out_3d [2 ].shape , [2 , 2 , 1 ])
538
+
539
+ np .testing .assert_allclose (
540
+ out_1d [0 ].numpy (), [123 ], rtol = RTOL , atol = ATOL
541
+ )
542
+ np .testing .assert_allclose (
543
+ out_1d [1 ].numpy (), [1 , 2 , 3 ], rtol = RTOL , atol = ATOL
544
+ )
545
+ np .testing .assert_allclose (
546
+ out_1d [2 ].numpy (), [[1 , 2 ], [3 , 4 ]], rtol = RTOL , atol = ATOL
547
+ )
548
+
549
+
550
+ class TestAtleastWithTensorTuple (unittest .TestCase ):
551
+ """Test when input is a tuple of paddle tensors"""
552
+
553
+ def test_tensor_tuple_input (self ):
554
+ for device , place in PLACES :
555
+ paddle .disable_static (place )
556
+ paddle .set_device (device )
557
+
558
+ tensor_tuple = (
559
+ paddle .to_tensor (123 , dtype = 'int32' ), # 0D
560
+ paddle .to_tensor ([1 , 2 , 3 ], dtype = 'float32' ), # 1D
561
+ paddle .to_tensor ([[1 , 2 ], [3 , 4 ]], dtype = 'int64' ), # 2D
562
+ )
563
+
564
+ # atleast_1d
565
+ out_1d = paddle .atleast_1d (tensor_tuple )
566
+ self .assertTrue (isinstance (out_1d , list ))
567
+ self .assertEqual (len (out_1d ), len (tensor_tuple ))
568
+
569
+ self .assertEqual (out_1d [0 ].shape , [1 ])
570
+ self .assertEqual (out_1d [1 ].shape , [3 ])
571
+ self .assertEqual (out_1d [2 ].shape , [2 , 2 ])
572
+
573
+ # atleast_2d
574
+ out_2d = paddle .atleast_2d (tensor_tuple )
575
+ self .assertTrue (isinstance (out_2d , list ))
576
+ self .assertEqual (len (out_2d ), len (tensor_tuple ))
577
+
578
+ self .assertEqual (out_2d [0 ].shape , [1 , 1 ])
579
+ self .assertEqual (out_2d [1 ].shape , [1 , 3 ])
580
+ self .assertEqual (out_2d [2 ].shape , [2 , 2 ])
581
+
582
+ # atleast_3d
583
+ out_3d = paddle .atleast_3d (tensor_tuple )
584
+ self .assertTrue (isinstance (out_3d , list ))
585
+ self .assertEqual (len (out_3d ), len (tensor_tuple ))
586
+
587
+ self .assertEqual (out_3d [0 ].shape , [1 , 1 , 1 ])
588
+ self .assertEqual (out_3d [1 ].shape , [1 , 3 , 1 ])
589
+ self .assertEqual (out_3d [2 ].shape , [2 , 2 , 1 ])
590
+
591
+ # Verify values are preserved correctly
592
+ np .testing .assert_allclose (
593
+ out_1d [0 ].numpy (), [123 ], rtol = RTOL , atol = ATOL
594
+ )
595
+ np .testing .assert_allclose (
596
+ out_1d [1 ].numpy (), [1 , 2 , 3 ], rtol = RTOL , atol = ATOL
597
+ )
598
+ np .testing .assert_allclose (
599
+ out_1d [2 ].numpy (), [[1 , 2 ], [3 , 4 ]], rtol = RTOL , atol = ATOL
600
+ )
601
+
602
+
603
+ class TestAtleastWithNestedList (unittest .TestCase ):
604
+ """Test when input is a list containing nested lists"""
605
+
606
+ def test_nested_list_input (self ):
607
+ for device , place in PLACES :
608
+ paddle .disable_static (place )
609
+ paddle .set_device (device )
610
+
611
+ nested_list = [[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ]]
612
+
613
+ # atleast_1d
614
+ out_1d = paddle .atleast_1d (nested_list )
615
+ self .assertEqual (out_1d .shape , [3 , 3 ])
616
+ self .assertTrue (isinstance (out_1d , paddle .Tensor ))
617
+
618
+ # atleast_2d
619
+ out_2d = paddle .atleast_2d (nested_list )
620
+ self .assertEqual (out_2d .shape , [3 , 3 ])
621
+ self .assertTrue (isinstance (out_2d , paddle .Tensor ))
622
+
623
+ # atleast_3d
624
+ out_3d = paddle .atleast_3d (nested_list )
625
+ self .assertEqual (out_3d .shape , [3 , 3 , 1 ])
626
+ self .assertTrue (isinstance (out_3d , paddle .Tensor ))
627
+
628
+
629
+ class TestAtleastWithNestedTuple (unittest .TestCase ):
630
+ """Test when input is a tuple containing nested tuples"""
631
+
632
+ def test_nested_tuple_input (self ):
633
+ for device , place in PLACES :
634
+ paddle .disable_static (place )
635
+ paddle .set_device (device )
636
+
637
+ nested_tuple = ((1 , 2 , 3 ), (1 , 2 , 3 ), (1 , 2 , 3 ))
638
+
639
+ # atleast_1d
640
+ out_1d = paddle .atleast_1d (nested_tuple )
641
+ self .assertEqual (out_1d .shape , [3 , 3 ])
642
+ self .assertTrue (isinstance (out_1d , paddle .Tensor ))
643
+
644
+ # atleast_2d
645
+ out_2d = paddle .atleast_2d (nested_tuple )
646
+ self .assertEqual (out_2d .shape , [3 , 3 ])
647
+ self .assertTrue (isinstance (out_2d , paddle .Tensor ))
648
+
649
+ # atleast_3d
650
+ out_3d = paddle .atleast_3d (nested_tuple )
651
+ self .assertEqual (out_3d .shape , [3 , 3 , 1 ])
652
+ self .assertTrue (isinstance (out_3d , paddle .Tensor ))
653
+
654
+
498
655
if __name__ == '__main__' :
499
656
unittest .main ()
0 commit comments