@@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
756
756
raise TypeError (
757
757
f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
758
758
)
759
- if not isinstance (ary_mask , dpt .usm_ndarray ):
760
- raise TypeError (
761
- f"Expecting type dpctl.tensor.usm_ndarray, got { type ( ary_mask ) } "
759
+ if isinstance (ary_mask , dpt .usm_ndarray ):
760
+ dst_usm_type = dpctl . utils . get_coerced_usm_type (
761
+ ( ary . usm_type , ary_mask . usm_type )
762
762
)
763
- dst_usm_type = dpctl .utils .get_coerced_usm_type (
764
- (ary .usm_type , ary_mask .usm_type )
765
- )
766
- exec_q = dpctl .utils .get_execution_queue (
767
- (ary .sycl_queue , ary_mask .sycl_queue )
768
- )
769
- if exec_q is None :
770
- raise dpctl .utils .ExecutionPlacementError (
771
- "arrays have different associated queues. "
772
- "Use `y.to_device(x.device)` to migrate."
763
+ exec_q = dpctl .utils .get_execution_queue (
764
+ (ary .sycl_queue , ary_mask .sycl_queue )
765
+ )
766
+ if exec_q is None :
767
+ raise dpctl .utils .ExecutionPlacementError (
768
+ "arrays have different associated queues. "
769
+ "Use `y.to_device(x.device)` to migrate."
770
+ )
771
+ elif isinstance (ary_mask , np .ndarray ):
772
+ dst_usm_type = ary .usm_type
773
+ exec_q = ary .sycl_queue
774
+ ary_mask = dpt .asarray (
775
+ ary_mask , usm_type = dst_usm_type , sycl_queue = exec_q
776
+ )
777
+ else :
778
+ raise TypeError (
779
+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
780
+ f"{ type (ary_mask )} "
773
781
)
774
782
ary_nd = ary .ndim
775
783
pp = normalize_axis_index (operator .index (axis ), ary_nd )
@@ -839,31 +847,32 @@ def _nonzero_impl(ary):
839
847
840
848
def _validate_indices (inds , queue_list , usm_type_list ):
841
849
"""
842
- Utility for validating indices are usm_ndarray of integral dtype or Python
843
- integers. At least one must be an array.
850
+ Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851
+ dtype or Python integers. At least one must be an array.
844
852
845
853
For each array, the queue and usm type are appended to `queue_list` and
846
854
`usm_type_list`, respectively.
847
855
"""
848
- any_usmarray = False
856
+ any_array = False
849
857
for ind in inds :
850
- if isinstance (ind , dpt .usm_ndarray ):
851
- any_usmarray = True
858
+ if isinstance (ind , ( np . ndarray , dpt .usm_ndarray ) ):
859
+ any_array = True
852
860
if ind .dtype .kind not in "ui" :
853
861
raise IndexError (
854
862
"arrays used as indices must be of integer (or boolean) "
855
863
"type"
856
864
)
857
- queue_list .append (ind .sycl_queue )
858
- usm_type_list .append (ind .usm_type )
865
+ if isinstance (ind , dpt .usm_ndarray ):
866
+ queue_list .append (ind .sycl_queue )
867
+ usm_type_list .append (ind .usm_type )
859
868
elif not isinstance (ind , Integral ):
860
869
raise TypeError (
861
- "all elements of `ind` expected to be usm_ndarrays "
862
- f"or integers, found { type (ind )} "
870
+ "all elements of `ind` expected to be usm_ndarrays, "
871
+ f"NumPy arrays, or integers, found { type (ind )} "
863
872
)
864
- if not any_usmarray :
873
+ if not any_array :
865
874
raise TypeError (
866
- "at least one element of `inds` expected to be a usm_ndarray "
875
+ "at least one element of `inds` expected to be an array "
867
876
)
868
877
return inds
869
878
@@ -942,8 +951,7 @@ def _take_multi_index(ary, inds, p, mode=0):
942
951
"be associated with the same queue."
943
952
)
944
953
945
- if len (inds ) > 1 :
946
- inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
954
+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
947
955
948
956
ind0 = inds [0 ]
949
957
ary_sh = ary .shape
@@ -976,16 +984,28 @@ def _place_impl(ary, ary_mask, vals, axis=0):
976
984
raise TypeError (
977
985
f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
978
986
)
979
- if not isinstance (ary_mask , dpt .usm_ndarray ):
980
- raise TypeError (
981
- f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
987
+ if isinstance (ary_mask , dpt .usm_ndarray ):
988
+ exec_q = dpctl .utils .get_execution_queue (
989
+ (
990
+ ary .sycl_queue ,
991
+ ary_mask .sycl_queue ,
992
+ )
982
993
)
983
- exec_q = dpctl .utils .get_execution_queue (
984
- (
985
- ary .sycl_queue ,
986
- ary_mask .sycl_queue ,
994
+ if exec_q is None :
995
+ raise dpctl .utils .ExecutionPlacementError (
996
+ "arrays have different associated queues. "
997
+ "Use `y.to_device(x.device)` to migrate."
998
+ )
999
+ elif isinstance (ary_mask , np .ndarray ):
1000
+ exec_q = ary .sycl_queue
1001
+ ary_mask = dpt .asarray (
1002
+ ary_mask , usm_type = ary .usm_type , sycl_queue = exec_q
1003
+ )
1004
+ else :
1005
+ raise TypeError (
1006
+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
1007
+ f"{ type (ary_mask )} "
987
1008
)
988
- )
989
1009
if exec_q is not None :
990
1010
if not isinstance (vals , dpt .usm_ndarray ):
991
1011
vals = dpt .asarray (vals , dtype = ary .dtype , sycl_queue = exec_q )
@@ -1080,8 +1100,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1080
1100
"be associated with the same queue."
1081
1101
)
1082
1102
1083
- if len (inds ) > 1 :
1084
- inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1103
+ inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1085
1104
1086
1105
ind0 = inds [0 ]
1087
1106
ary_sh = ary .shape
0 commit comments