Skip to content

Commit e9d0f5a

Browse files
committed
cosmetic
1 parent 980daf0 commit e9d0f5a

File tree

1 file changed

+37
-23
lines changed

1 file changed

+37
-23
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,25 @@
4040
isdtype = get_xp(np)(_aliases.isdtype)
4141
unstack = get_xp(da)(_aliases.unstack)
4242

43+
# da.astype doesn't respect copy=True
4344
def astype(
4445
x: Array,
4546
dtype: Dtype,
4647
/,
4748
*,
4849
copy: bool = True,
49-
device: Device | None = None
50+
device: Optional[Device] = None
5051
) -> Array:
52+
"""
53+
Array API compatibility wrapper for astype().
54+
55+
See the corresponding documentation in the array library and/or the array API
56+
specification for more details.
57+
"""
5158
# TODO: respect device keyword?
59+
5260
if not copy and dtype == x.dtype:
5361
return x
54-
# dask astype doesn't respect copy=True,
55-
# so call copy manually afterwards
5662
x = x.astype(dtype)
5763
return x.copy() if copy else x
5864

@@ -61,20 +67,24 @@ def astype(
6167
# This arange func is modified from the common one to
6268
# not pass stop/step as keyword arguments, which will cause
6369
# an error with dask
64-
65-
# TODO: delete the xp stuff, it shouldn't be necessary
66-
def _dask_arange(
70+
def arange(
6771
start: Union[int, float],
6872
/,
6973
stop: Optional[Union[int, float]] = None,
7074
step: Union[int, float] = 1,
7175
*,
72-
xp,
7376
dtype: Optional[Dtype] = None,
7477
device: Optional[Device] = None,
7578
**kwargs,
7679
) -> Array:
77-
_check_device(xp, device)
80+
"""
81+
Array API compatibility wrapper for arange().
82+
83+
See the corresponding documentation in the array library and/or the array API
84+
specification for more details.
85+
"""
86+
# TODO: respect device keyword?
87+
7888
args = [start]
7989
if stop is not None:
8090
args.append(stop)
@@ -83,13 +93,12 @@ def _dask_arange(
8393
# prepend the default value for start which is 0
8494
args.insert(0, 0)
8595
args.append(step)
86-
return xp.arange(*args, dtype=dtype, **kwargs)
8796

88-
arange = get_xp(da)(_dask_arange)
89-
eye = get_xp(da)(_aliases.eye)
97+
return da.arange(*args, dtype=dtype, **kwargs)
98+
9099

91-
linspace = get_xp(da)(_aliases.linspace)
92100
eye = get_xp(da)(_aliases.eye)
101+
linspace = get_xp(da)(_aliases.linspace)
93102
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
94103
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
95104
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
@@ -112,7 +121,6 @@ def _dask_arange(
112121
reshape = get_xp(da)(_aliases.reshape)
113122
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
114123
vecdot = get_xp(da)(_aliases.vecdot)
115-
116124
nonzero = get_xp(da)(_aliases.nonzero)
117125
ceil = get_xp(np)(_aliases.ceil)
118126
floor = get_xp(np)(_aliases.floor)
@@ -121,6 +129,7 @@ def _dask_arange(
121129
tensordot = get_xp(np)(_aliases.tensordot)
122130
sign = get_xp(np)(_aliases.sign)
123131

132+
124133
# asarray also adds the copy keyword, which is not present in numpy 1.0.
125134
def asarray(
126135
obj: Union[
@@ -135,7 +144,7 @@ def asarray(
135144
*,
136145
dtype: Optional[Dtype] = None,
137146
device: Optional[Device] = None,
138-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
147+
copy: Optional[Union[bool, np._CopyMode]] = None,
139148
**kwargs,
140149
) -> Array:
141150
"""
@@ -144,6 +153,8 @@ def asarray(
144153
See the corresponding documentation in the array library and/or the array API
145154
specification for more details.
146155
"""
156+
# TODO: respect device keyword?
157+
147158
if isinstance(obj, da.Array):
148159
if dtype is not None and dtype != obj.dtype:
149160
if copy is False:
@@ -183,15 +194,18 @@ def asarray(
183194
# Furthermore, the masking workaround in common._aliases.clip cannot work with
184195
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
185196
# now).
186-
@get_xp(da)
187197
def clip(
188198
x: Array,
189199
/,
190200
min: Optional[Union[int, float, Array]] = None,
191201
max: Optional[Union[int, float, Array]] = None,
192-
*,
193-
xp,
194202
) -> Array:
203+
"""
204+
Array API compatibility wrapper for clip().
205+
206+
See the corresponding documentation in the array library and/or the array API
207+
specification for more details.
208+
"""
195209
def _isscalar(a):
196210
return isinstance(a, (int, float, type(None)))
197211
min_shape = () if _isscalar(min) else min.shape
@@ -201,19 +215,19 @@ def _isscalar(a):
201215
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
202216

203217
if min is not None:
204-
min = xp.broadcast_to(xp.asarray(min), result_shape)
218+
min = da.broadcast_to(da.asarray(min), result_shape)
205219
if max is not None:
206-
max = xp.broadcast_to(xp.asarray(max), result_shape)
220+
max = da.broadcast_to(da.asarray(max), result_shape)
207221

208222
if min is None and max is None:
209-
return xp.positive(x)
223+
return da.positive(x)
210224

211225
if min is None:
212-
return astype(xp.minimum(x, max), x.dtype)
226+
return astype(da.minimum(x, max), x.dtype)
213227
if max is None:
214-
return astype(xp.maximum(x, min), x.dtype)
228+
return astype(da.maximum(x, min), x.dtype)
215229

216-
return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
230+
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
217231

218232
# exclude these from all since dask.array has no sorting functions
219233
_da_unsupported = ['sort', 'argsort']

0 commit comments

Comments
 (0)