@@ -725,3 +725,163 @@ def test_roll_2d(data):
725
725
Y = dpt .roll (X , sh , ax )
726
726
Ynp = np .roll (Xnp , sh , ax )
727
727
assert_array_equal (Ynp , dpt .asnumpy (Y ))
728
+
729
+
730
+ def test_concat_incorrect_type ():
731
+ Xnp = np .ones ((2 , 2 ))
732
+ pytest .raises (TypeError , dpt .concat )
733
+ pytest .raises (TypeError , dpt .concat , [])
734
+ pytest .raises (TypeError , dpt .concat , Xnp )
735
+ pytest .raises (TypeError , dpt .concat , [Xnp , Xnp ])
736
+
737
+
738
+ def test_concat_incorrect_queue ():
739
+ try :
740
+ q1 = dpctl .SyclQueue ()
741
+ q2 = dpctl .SyclQueue ()
742
+ except dpctl .SyclQueueCreationError :
743
+ pytest .skip ("Queue could not be created" )
744
+
745
+ X = dpt .ones ((2 , 2 ), sycl_queue = q1 )
746
+ Y = dpt .ones ((2 , 2 ), sycl_queue = q2 )
747
+
748
+ pytest .raises (ValueError , dpt .concat , [X , Y ])
749
+
750
+
751
+ def test_concat_incorrect_dtype ():
752
+ try :
753
+ q = dpctl .SyclQueue ()
754
+ except dpctl .SyclQueueCreationError :
755
+ pytest .skip ("Queue could not be created" )
756
+
757
+ X = dpt .ones ((2 , 2 ), dtype = np .int64 , sycl_queue = q )
758
+ Y = dpt .ones ((2 , 2 ), dtype = np .uint64 , sycl_queue = q )
759
+
760
+ pytest .raises (ValueError , dpt .concat , [X , Y ])
761
+
762
+
763
+ def test_concat_incorrect_ndim ():
764
+ try :
765
+ q = dpctl .SyclQueue ()
766
+ except dpctl .SyclQueueCreationError :
767
+ pytest .skip ("Queue could not be created" )
768
+
769
+ X = dpt .ones ((2 , 2 ), sycl_queue = q )
770
+ Y = dpt .ones ((2 , 2 , 2 ), sycl_queue = q )
771
+
772
+ pytest .raises (ValueError , dpt .concat , [X , Y ])
773
+
774
+
775
+ @pytest .mark .parametrize (
776
+ "data" ,
777
+ [
778
+ [(2 , 2 ), (3 , 3 ), 0 ],
779
+ [(2 , 2 ), (3 , 3 ), 1 ],
780
+ [(3 , 2 ), (3 , 3 ), 0 ],
781
+ [(2 , 3 ), (3 , 3 ), 1 ],
782
+ ],
783
+ )
784
+ def test_concat_incorrect_shape (data ):
785
+ try :
786
+ q = dpctl .SyclQueue ()
787
+ except dpctl .SyclQueueCreationError :
788
+ pytest .skip ("Queue could not be created" )
789
+
790
+ Xshape , Yshape , axis = data
791
+
792
+ X = dpt .ones (Xshape , sycl_queue = q )
793
+ Y = dpt .ones (Yshape , sycl_queue = q )
794
+
795
+ pytest .raises (ValueError , dpt .concat , [X , Y ], axis )
796
+
797
+
798
+ @pytest .mark .parametrize (
799
+ "data" ,
800
+ [
801
+ [(6 ,), 0 ],
802
+ [(2 , 3 ), 1 ],
803
+ [(3 , 2 ), - 1 ],
804
+ [(1 , 6 ), 0 ],
805
+ [(2 , 1 , 3 ), 2 ],
806
+ ],
807
+ )
808
+ def test_concat_1array (data ):
809
+ try :
810
+ q = dpctl .SyclQueue ()
811
+ except dpctl .SyclQueueCreationError :
812
+ pytest .skip ("Queue could not be created" )
813
+
814
+ Xshape , axis = data
815
+
816
+ Xnp = np .arange (6 ).reshape (Xshape )
817
+ X = dpt .asarray (Xnp , sycl_queue = q )
818
+
819
+ Ynp = np .concatenate ([Xnp ], axis = axis )
820
+ Y = dpt .concat ([X ], axis = axis )
821
+
822
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
823
+
824
+ Ynp = np .concatenate ((Xnp ,), axis = axis )
825
+ Y = dpt .concat ((X ,), axis = axis )
826
+
827
+ assert_array_equal (Ynp , dpt .asnumpy (Y ))
828
+
829
+
830
+ @pytest .mark .parametrize (
831
+ "data" ,
832
+ [
833
+ [(1 ,), (1 ,), 0 ],
834
+ [(0 , 2 ), (2 , 2 ), 0 ],
835
+ [(2 , 1 ), (2 , 2 ), - 1 ],
836
+ [(2 , 2 , 2 ), (2 , 1 , 2 ), 1 ],
837
+ ],
838
+ )
839
+ def test_concat_2arrays (data ):
840
+ try :
841
+ q = dpctl .SyclQueue ()
842
+ except dpctl .SyclQueueCreationError :
843
+ pytest .skip ("Queue could not be created" )
844
+
845
+ Xshape , Yshape , axis = data
846
+
847
+ Xnp = np .ones (Xshape )
848
+ X = dpt .asarray (Xnp , sycl_queue = q )
849
+
850
+ Ynp = np .zeros (Yshape )
851
+ Y = dpt .asarray (Ynp , sycl_queue = q )
852
+
853
+ Znp = np .concatenate ([Xnp , Ynp ], axis = axis )
854
+ Z = dpt .concat ([X , Y ], axis = axis )
855
+
856
+ assert_array_equal (Znp , dpt .asnumpy (Z ))
857
+
858
+
859
+ @pytest .mark .parametrize (
860
+ "data" ,
861
+ [
862
+ [(1 ,), (1 ,), (1 ,), 0 ],
863
+ [(0 , 2 ), (2 , 2 ), (1 , 2 ), 0 ],
864
+ [(2 , 1 , 2 ), (2 , 2 , 2 ), (2 , 4 , 2 ), 1 ],
865
+ ],
866
+ )
867
+ def test_concat_3arrays (data ):
868
+ try :
869
+ q = dpctl .SyclQueue ()
870
+ except dpctl .SyclQueueCreationError :
871
+ pytest .skip ("Queue could not be created" )
872
+
873
+ Xshape , Yshape , Zshape , axis = data
874
+
875
+ Xnp = np .ones (Xshape )
876
+ X = dpt .asarray (Xnp , sycl_queue = q )
877
+
878
+ Ynp = np .zeros (Yshape )
879
+ Y = dpt .asarray (Ynp , sycl_queue = q )
880
+
881
+ Znp = np .full (Zshape , 2.0 )
882
+ Z = dpt .asarray (Znp , sycl_queue = q )
883
+
884
+ Rnp = np .concatenate ([Xnp , Ynp , Znp ], axis = axis )
885
+ R = dpt .concat ([X , Y , Z ], axis = axis )
886
+
887
+ assert_array_equal (Rnp , dpt .asnumpy (R ))
0 commit comments