Skip to content

Commit b4e94b0

Browse files
authored
Add new 2023.12 elemwise functions: clip, copysign, hypot, maximum, minimum, signbit. (#583)
1 parent 73bbf5c commit b4e94b0

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

api_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
4141
| Data Types | `bool`, `int8`, ... | :white_check_mark: | | |
4242
| Elementwise Functions | `add` | :white_check_mark: | | Example of a binary function |
4343
| | `negative` | :white_check_mark: | | Example of a unary function |
44-
| | _others_ | :white_check_mark: | | Except 2023.12 functions in [#438](https://github.com/cubed-dev/cubed/issues/438) |
44+
| | _others_ | :white_check_mark: | | |
4545
| Indexing | Single-axis | :white_check_mark: | | |
4646
| | Multi-axis | :white_check_mark: | | |
4747
| | Boolean array | :x: | | Shape is data dependent, [#73](https://github.com/cubed-dev/cubed/issues/73) |

cubed/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@
153153
bitwise_right_shift,
154154
bitwise_xor,
155155
ceil,
156+
clip,
156157
conj,
158+
copysign,
157159
cos,
158160
cosh,
159161
divide,
@@ -164,6 +166,7 @@
164166
floor_divide,
165167
greater,
166168
greater_equal,
169+
hypot,
167170
imag,
168171
isfinite,
169172
isinf,
@@ -179,6 +182,8 @@
179182
logical_not,
180183
logical_or,
181184
logical_xor,
185+
maximum,
186+
minimum,
182187
multiply,
183188
negative,
184189
not_equal,
@@ -188,6 +193,7 @@
188193
remainder,
189194
round,
190195
sign,
196+
signbit,
191197
sin,
192198
sinh,
193199
sqrt,
@@ -215,7 +221,9 @@
215221
"bitwise_right_shift",
216222
"bitwise_xor",
217223
"ceil",
224+
"clip",
218225
"conj",
226+
"copysign",
219227
"cos",
220228
"cosh",
221229
"divide",
@@ -226,6 +234,7 @@
226234
"floor_divide",
227235
"greater",
228236
"greater_equal",
237+
"hypot",
229238
"imag",
230239
"isfinite",
231240
"isinf",
@@ -241,6 +250,8 @@
241250
"logical_not",
242251
"logical_or",
243252
"logical_xor",
253+
"maximum",
254+
"minimum",
244255
"multiply",
245256
"negative",
246257
"not_equal",
@@ -250,6 +261,7 @@
250261
"remainder",
251262
"round",
252263
"sign",
264+
"signbit",
253265
"sin",
254266
"sinh",
255267
"sqrt",

cubed/array_api/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@
101101
bitwise_right_shift,
102102
bitwise_xor,
103103
ceil,
104+
clip,
104105
conj,
106+
copysign,
105107
cos,
106108
cosh,
107109
divide,
@@ -112,6 +114,7 @@
112114
floor_divide,
113115
greater,
114116
greater_equal,
117+
hypot,
115118
imag,
116119
isfinite,
117120
isinf,
@@ -127,6 +130,8 @@
127130
logical_not,
128131
logical_or,
129132
logical_xor,
133+
maximum,
134+
minimum,
130135
multiply,
131136
negative,
132137
not_equal,
@@ -136,6 +141,7 @@
136141
remainder,
137142
round,
138143
sign,
144+
signbit,
139145
sin,
140146
sinh,
141147
sqrt,
@@ -163,7 +169,9 @@
163169
"bitwise_right_shift",
164170
"bitwise_xor",
165171
"ceil",
172+
"clip",
166173
"conj",
174+
"copysign",
167175
"cos",
168176
"cosh",
169177
"divide",
@@ -174,6 +182,7 @@
174182
"floor_divide",
175183
"greater",
176184
"greater_equal",
185+
"hypot",
177186
"imag",
178187
"isfinite",
179188
"isinf",
@@ -189,6 +198,8 @@
189198
"logical_not",
190199
"logical_or",
191200
"logical_xor",
201+
"maximum",
202+
"minimum",
192203
"multiply",
193204
"negative",
194205
"not_equal",
@@ -198,6 +209,7 @@
198209
"remainder",
199210
"round",
200211
"sign",
212+
"signbit",
201213
"sin",
202214
"sinh",
203215
"sqrt",

cubed/array_api/elementwise_functions.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from cubed.array_api.array_object import Array
2+
from cubed.array_api.creation_functions import asarray
13
from cubed.array_api.data_type_functions import result_type
24
from cubed.array_api.dtypes import (
35
_boolean_dtypes,
@@ -131,12 +133,50 @@ def ceil(x, /):
131133
return elemwise(nxp.ceil, x, dtype=x.dtype)
132134

133135

136+
def clip(x, /, min=None, max=None):
137+
if (
138+
x.dtype not in _real_numeric_dtypes
139+
or isinstance(min, Array)
140+
and min.dtype not in _real_numeric_dtypes
141+
or isinstance(max, Array)
142+
and max.dtype not in _real_numeric_dtypes
143+
):
144+
raise TypeError("Only real numeric dtypes are allowed in clip")
145+
if not isinstance(min, (int, float, Array, type(None))):
146+
raise TypeError("min must be an None, int, float, or an array")
147+
if not isinstance(max, (int, float, Array, type(None))):
148+
raise TypeError("max must be an None, int, float, or an array")
149+
150+
if min is max is None:
151+
return x
152+
elif min is not None and max is None:
153+
min = asarray(min, spec=x.spec)
154+
return elemwise(nxp.clip, x, min, dtype=x.dtype)
155+
elif min is None and max is not None:
156+
157+
def clip_max(x_, max_):
158+
return nxp.clip(x_, max=max_)
159+
160+
max = asarray(max, spec=x.spec)
161+
return elemwise(clip_max, x, max, dtype=x.dtype)
162+
else: # min is not None and max is not None
163+
min = asarray(min, spec=x.spec)
164+
max = asarray(max, spec=x.spec)
165+
return elemwise(nxp.clip, x, min, max, dtype=x.dtype)
166+
167+
134168
def conj(x, /):
135169
if x.dtype not in _complex_floating_dtypes:
136170
raise TypeError("Only complex floating-point dtypes are allowed in conj")
137171
return elemwise(nxp.conj, x, dtype=x.dtype)
138172

139173

174+
def copysign(x1, x2, /):
175+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
176+
raise TypeError("Only real numeric dtypes are allowed in copysign")
177+
return elemwise(nxp.copysign, x1, x2, dtype=result_type(x1, x2))
178+
179+
140180
def cos(x, /):
141181
if x.dtype not in _floating_dtypes:
142182
raise TypeError("Only floating-point dtypes are allowed in cos")
@@ -194,6 +234,12 @@ def greater_equal(x1, x2, /):
194234
return elemwise(nxp.greater_equal, x1, x2, dtype=nxp.bool)
195235

196236

237+
def hypot(x1, x2, /):
238+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
239+
raise TypeError("Only real numeric dtypes are allowed in hypot")
240+
return elemwise(nxp.hypot, x1, x2, dtype=result_type(x1, x2))
241+
242+
197243
def imag(x, /):
198244
if x.dtype == complex64:
199245
dtype = float32
@@ -284,6 +330,18 @@ def logical_xor(x1, x2, /):
284330
return elemwise(nxp.logical_xor, x1, x2, dtype=nxp.bool)
285331

286332

333+
def maximum(x1, x2, /):
334+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
335+
raise TypeError("Only real numeric dtypes are allowed in maximum")
336+
return elemwise(nxp.maximum, x1, x2, dtype=result_type(x1, x2))
337+
338+
339+
def minimum(x1, x2, /):
340+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
341+
raise TypeError("Only real numeric dtypes are allowed in minimum")
342+
return elemwise(nxp.minimum, x1, x2, dtype=result_type(x1, x2))
343+
344+
287345
def multiply(x1, x2, /):
288346
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
289347
raise TypeError("Only numeric dtypes are allowed in multiply")
@@ -340,6 +398,12 @@ def sign(x, /):
340398
return elemwise(nxp.sign, x, dtype=x.dtype)
341399

342400

401+
def signbit(x, /):
402+
if x.dtype not in _real_numeric_dtypes:
403+
raise TypeError("Only real numeric dtypes are allowed in signbit")
404+
return elemwise(nxp.signbit, x, dtype=nxp.bool)
405+
406+
343407
def sin(x, /):
344408
if x.dtype not in _floating_dtypes:
345409
raise TypeError("Only floating-point dtypes are allowed in sin")

cubed/tests/test_array_api.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,26 @@ def test_add_different_chunks_fail(spec, executor):
194194
assert_array_equal(c.compute(executor=executor), np.ones((10,)) + np.ones((10,)))
195195

196196

197+
@pytest.mark.parametrize(
198+
"min, max",
199+
[
200+
(None, None),
201+
(4, None),
202+
(None, 7),
203+
(4, 7),
204+
(0, 10),
205+
],
206+
)
207+
def test_clip(spec, min, max):
208+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
209+
npa = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
210+
b = xp.clip(a, min, max)
211+
if min is max is None:
212+
assert b is a
213+
else:
214+
assert_array_equal(b.compute(), np.clip(npa, min, max))
215+
216+
197217
def test_equal(spec):
198218
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
199219
b = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)

0 commit comments

Comments
 (0)