@@ -535,6 +535,21 @@ def test_put_basic_axis():
535
535
assert (expected == dpt .asnumpy (x )).all ()
536
536
537
537
538
+ @pytest .mark .parametrize ("data_dt" , _all_dtypes )
539
+ def test_put_0d_val (data_dt ):
540
+ q = get_queue_or_skip ()
541
+ skip_if_dtype_not_supported (data_dt , q )
542
+
543
+ x = dpt .arange (5 , dtype = data_dt , sycl_queue = q )
544
+ ind = dpt .asarray ([0 ], dtype = np .intp , sycl_queue = q )
545
+ x [ind ] = 2
546
+ assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x [0 ]))
547
+
548
+ x = dpt .asarray (5 , dtype = data_dt , sycl_queue = q )
549
+ x [ind ] = 2
550
+ assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x ))
551
+
552
+
538
553
@pytest .mark .parametrize (
539
554
"data_dt" ,
540
555
_all_dtypes ,
@@ -543,8 +558,8 @@ def test_take_0d_data(data_dt):
543
558
q = get_queue_or_skip ()
544
559
skip_if_dtype_not_supported (data_dt , q )
545
560
546
- x = dpt .asarray (0 , dtype = data_dt )
547
- ind = dpt .arange (5 )
561
+ x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
562
+ ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
548
563
549
564
y = dpt .take (x , ind )
550
565
assert (
@@ -561,9 +576,9 @@ def test_put_0d_data(data_dt):
561
576
q = get_queue_or_skip ()
562
577
skip_if_dtype_not_supported (data_dt , q )
563
578
564
- x = dpt .asarray (0 , dtype = data_dt )
565
- ind = dpt .arange (5 )
566
- val = dpt .asarray (2 , dtype = data_dt )
579
+ x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
580
+ ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
581
+ val = dpt .asarray (2 , dtype = data_dt , sycl_queue = q )
567
582
568
583
dpt .put (x , ind , val , axis = 0 )
569
584
assert (
@@ -577,10 +592,10 @@ def test_put_0d_data(data_dt):
577
592
_all_int_dtypes ,
578
593
)
579
594
def test_take_0d_ind (ind_dt ):
580
- get_queue_or_skip ()
595
+ q = get_queue_or_skip ()
581
596
582
- x = dpt .arange (5 , dtype = ind_dt )
583
- ind = dpt .asarray (3 )
597
+ x = dpt .arange (5 , dtype = "i4" , sycl_queue = q )
598
+ ind = dpt .asarray (3 , dtype = ind_dt , sycl_queue = q )
584
599
585
600
y = dpt .take (x , ind )
586
601
assert dpt .asnumpy (x [3 ]) == dpt .asnumpy (y )
@@ -591,11 +606,11 @@ def test_take_0d_ind(ind_dt):
591
606
_all_int_dtypes ,
592
607
)
593
608
def test_put_0d_ind (ind_dt ):
594
- get_queue_or_skip ()
609
+ q = get_queue_or_skip ()
595
610
596
- x = dpt .arange (5 , dtype = ind_dt )
597
- ind = dpt .asarray (3 )
598
- val = dpt .asarray (5 , dtype = ind_dt )
611
+ x = dpt .arange (5 , dtype = "i4" , sycl_queue = q )
612
+ ind = dpt .asarray (3 , dtype = ind_dt , sycl_queue = q )
613
+ val = dpt .asarray (5 , dtype = x . dtype , sycl_queue = q )
599
614
600
615
dpt .put (x , ind , val , axis = 0 )
601
616
assert dpt .asnumpy (x [3 ]) == dpt .asnumpy (val )
@@ -750,7 +765,7 @@ def test_put_strided_1d_destination(data_dt, order):
750
765
751
766
x = dpt .arange (27 , dtype = data_dt , sycl_queue = q )
752
767
ind = dpt .arange (4 , 9 , dtype = np .intp , sycl_queue = q )
753
- val = dpt .asarray (9 , dtype = data_dt , sycl_queue = q )
768
+ val = dpt .asarray (9 , dtype = x . dtype , sycl_queue = q )
754
769
755
770
x_np = dpt .asnumpy (x )
756
771
ind_np = dpt .asnumpy (ind )
@@ -780,7 +795,7 @@ def test_put_strided_destination(data_dt, order):
780
795
781
796
x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
782
797
ind = dpt .arange (2 , dtype = np .intp , sycl_queue = q )
783
- val = dpt .asarray (9 , dtype = data_dt , sycl_queue = q )
798
+ val = dpt .asarray (9 , dtype = x . dtype , sycl_queue = q )
784
799
785
800
x_np = dpt .asnumpy (x )
786
801
ind_np = dpt .asnumpy (ind )
@@ -825,7 +840,7 @@ def test_put_strided_1d_indices(ind_dt):
825
840
826
841
x = dpt .arange (27 , dtype = "i4" , sycl_queue = q )
827
842
ind = dpt .arange (12 , 24 , dtype = ind_dt , sycl_queue = q )
828
- val = dpt .asarray (- 1 , dtype = "i4" , sycl_queue = q )
843
+ val = dpt .asarray (- 1 , dtype = x . dtype , sycl_queue = q )
829
844
830
845
x_np = dpt .asnumpy (x )
831
846
ind_np = dpt .asnumpy (ind ).astype (np .intp )
@@ -880,43 +895,53 @@ def test_put_strided_indices(ind_dt, order):
880
895
881
896
882
897
def test_take_arg_validation ():
883
- get_queue_or_skip ()
898
+ q = get_queue_or_skip ()
884
899
885
- x = dpt .arange (4 )
886
- ind0 = dpt .arange (2 )
887
- ind1 = dpt .arange (2.0 )
900
+ x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
901
+ ind0 = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
902
+ ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
888
903
889
- with pytest .raises (ValueError ):
890
- dpt .take (dpt .reshape (x , (2 , 2 )), ind0 )
891
904
with pytest .raises (TypeError ):
892
905
dpt .take (dict (), ind0 , axis = 0 )
893
906
with pytest .raises (TypeError ):
894
907
dpt .take (x , dict (), axis = 0 )
895
908
with pytest .raises (TypeError ):
909
+ x [[]]
910
+ with pytest .raises (IndexError ):
896
911
dpt .take (x , ind1 , axis = 0 )
912
+ with pytest .raises (IndexError ):
913
+ x [ind1 ]
897
914
915
+ with pytest .raises (ValueError ):
916
+ dpt .take (dpt .reshape (x , (2 , 2 )), ind0 )
898
917
with pytest .raises (ValueError ):
899
918
dpt .take (x , ind0 , mode = 0 )
900
919
with pytest .raises (ValueError ):
901
920
dpt .take (dpt .reshape (x , (2 , 2 )), ind0 , axis = None )
902
921
903
922
904
923
def test_put_arg_validation ():
905
- get_queue_or_skip ()
924
+ q = get_queue_or_skip ()
906
925
907
- x = dpt .arange (4 )
908
- ind0 = dpt .arange (2 )
909
- ind1 = dpt .arange (2.0 )
910
- val = dpt .asarray (2 )
926
+ x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
927
+ ind0 = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
928
+ ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
929
+ val = dpt .asarray (2 , x . dtype , sycl_queue = q )
911
930
912
931
with pytest .raises (TypeError ):
913
932
dpt .put (dict (), ind0 , val , axis = 0 )
914
933
with pytest .raises (TypeError ):
915
934
dpt .put (x , dict (), val , axis = 0 )
916
935
with pytest .raises (TypeError ):
936
+ x [[]] = val
937
+ with pytest .raises (IndexError ):
917
938
dpt .put (x , ind1 , val , axis = 0 )
939
+ with pytest .raises (IndexError ):
940
+ x [ind1 ] = val
918
941
with pytest .raises (TypeError ):
919
942
dpt .put (x , ind0 , dict (), axis = 0 )
943
+ with pytest .raises (TypeError ):
944
+ x [ind0 ] = dict ()
920
945
921
946
with pytest .raises (ValueError ):
922
947
dpt .put (x , ind0 , val , mode = 0 )
0 commit comments