40
40
isdtype = get_xp (np )(_aliases .isdtype )
41
41
unstack = get_xp (da )(_aliases .unstack )
42
42
43
+ # da.astype doesn't respect copy=True
43
44
def astype (
44
45
x : Array ,
45
46
dtype : Dtype ,
46
47
/ ,
47
48
* ,
48
49
copy : bool = True ,
49
- device : Device | None = None
50
+ device : Optional [ Device ] = None
50
51
) -> Array :
52
+ """
53
+ Array API compatibility wrapper for astype().
54
+
55
+ See the corresponding documentation in the array library and/or the array API
56
+ specification for more details.
57
+ """
51
58
# TODO: respect device keyword?
59
+
52
60
if not copy and dtype == x .dtype :
53
61
return x
54
- # dask astype doesn't respect copy=True,
55
- # so call copy manually afterwards
56
62
x = x .astype (dtype )
57
63
return x .copy () if copy else x
58
64
@@ -61,20 +67,24 @@ def astype(
61
67
# This arange func is modified from the common one to
62
68
# not pass stop/step as keyword arguments, which will cause
63
69
# an error with dask
64
-
65
- # TODO: delete the xp stuff, it shouldn't be necessary
66
- def _dask_arange (
70
+ def arange (
67
71
start : Union [int , float ],
68
72
/ ,
69
73
stop : Optional [Union [int , float ]] = None ,
70
74
step : Union [int , float ] = 1 ,
71
75
* ,
72
- xp ,
73
76
dtype : Optional [Dtype ] = None ,
74
77
device : Optional [Device ] = None ,
75
78
** kwargs ,
76
79
) -> Array :
77
- _check_device (xp , device )
80
+ """
81
+ Array API compatibility wrapper for arange().
82
+
83
+ See the corresponding documentation in the array library and/or the array API
84
+ specification for more details.
85
+ """
86
+ # TODO: respect device keyword?
87
+
78
88
args = [start ]
79
89
if stop is not None :
80
90
args .append (stop )
@@ -83,13 +93,12 @@ def _dask_arange(
83
93
# prepend the default value for start which is 0
84
94
args .insert (0 , 0 )
85
95
args .append (step )
86
- return xp .arange (* args , dtype = dtype , ** kwargs )
87
96
88
- arange = get_xp ( da )( _dask_arange )
89
- eye = get_xp ( da )( _aliases . eye )
97
+ return da . arange ( * args , dtype = dtype , ** kwargs )
98
+
90
99
91
- linspace = get_xp (da )(_aliases .linspace )
92
100
eye = get_xp (da )(_aliases .eye )
101
+ linspace = get_xp (da )(_aliases .linspace )
93
102
UniqueAllResult = get_xp (da )(_aliases .UniqueAllResult )
94
103
UniqueCountsResult = get_xp (da )(_aliases .UniqueCountsResult )
95
104
UniqueInverseResult = get_xp (da )(_aliases .UniqueInverseResult )
@@ -112,7 +121,6 @@ def _dask_arange(
112
121
reshape = get_xp (da )(_aliases .reshape )
113
122
matrix_transpose = get_xp (da )(_aliases .matrix_transpose )
114
123
vecdot = get_xp (da )(_aliases .vecdot )
115
-
116
124
nonzero = get_xp (da )(_aliases .nonzero )
117
125
ceil = get_xp (np )(_aliases .ceil )
118
126
floor = get_xp (np )(_aliases .floor )
@@ -121,6 +129,7 @@ def _dask_arange(
121
129
tensordot = get_xp (np )(_aliases .tensordot )
122
130
sign = get_xp (np )(_aliases .sign )
123
131
132
+
124
133
# asarray also adds the copy keyword, which is not present in numpy 1.0.
125
134
def asarray (
126
135
obj : Union [
@@ -135,7 +144,7 @@ def asarray(
135
144
* ,
136
145
dtype : Optional [Dtype ] = None ,
137
146
device : Optional [Device ] = None ,
138
- copy : " Optional[Union[bool, np._CopyMode]]" = None ,
147
+ copy : Optional [Union [bool , np ._CopyMode ]] = None ,
139
148
** kwargs ,
140
149
) -> Array :
141
150
"""
@@ -144,6 +153,8 @@ def asarray(
144
153
See the corresponding documentation in the array library and/or the array API
145
154
specification for more details.
146
155
"""
156
+ # TODO: respect device keyword?
157
+
147
158
if isinstance (obj , da .Array ):
148
159
if dtype is not None and dtype != obj .dtype :
149
160
if copy is False :
@@ -183,15 +194,18 @@ def asarray(
183
194
# Furthermore, the masking workaround in common._aliases.clip cannot work with
184
195
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
185
196
# now).
186
- @get_xp (da )
187
197
def clip (
188
198
x : Array ,
189
199
/ ,
190
200
min : Optional [Union [int , float , Array ]] = None ,
191
201
max : Optional [Union [int , float , Array ]] = None ,
192
- * ,
193
- xp ,
194
202
) -> Array :
203
+ """
204
+ Array API compatibility wrapper for clip().
205
+
206
+ See the corresponding documentation in the array library and/or the array API
207
+ specification for more details.
208
+ """
195
209
def _isscalar (a ):
196
210
return isinstance (a , (int , float , type (None )))
197
211
min_shape = () if _isscalar (min ) else min .shape
@@ -201,19 +215,19 @@ def _isscalar(a):
201
215
result_shape = np .broadcast_shapes (x .shape , min_shape , max_shape )
202
216
203
217
if min is not None :
204
- min = xp .broadcast_to (xp .asarray (min ), result_shape )
218
+ min = da .broadcast_to (da .asarray (min ), result_shape )
205
219
if max is not None :
206
- max = xp .broadcast_to (xp .asarray (max ), result_shape )
220
+ max = da .broadcast_to (da .asarray (max ), result_shape )
207
221
208
222
if min is None and max is None :
209
- return xp .positive (x )
223
+ return da .positive (x )
210
224
211
225
if min is None :
212
- return astype (xp .minimum (x , max ), x .dtype )
226
+ return astype (da .minimum (x , max ), x .dtype )
213
227
if max is None :
214
- return astype (xp .maximum (x , min ), x .dtype )
228
+ return astype (da .maximum (x , min ), x .dtype )
215
229
216
- return astype (xp .minimum (xp .maximum (x , min ), max ), x .dtype )
230
+ return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
217
231
218
232
# exclude these from all since dask.array has no sorting functions
219
233
_da_unsupported = ['sort' , 'argsort' ]
0 commit comments