Skip to content

Commit f81c2f2

Browse files
authored
Array API 2024 additions (#779)
* Add 'max dimensions' to capabilities * Add `nextafter` * Add `reciprocal` * Add `count_nonzero`
1 parent 1e5d9e5 commit f81c2f2

File tree

6 files changed

+50
-10
lines changed

6 files changed

+50
-10
lines changed

api_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
6969
| | `unstack` | :white_check_mark: | 2023.12 | |
7070
| Searching Functions | `argmax` | :white_check_mark: | | |
7171
| | `argmin` | :white_check_mark: | | |
72-
| | `count_nonzero` | :x: | 2024.12 | |
72+
| | `count_nonzero` | :white_check_mark: | 2024.12 | |
7373
| | `nonzero` | :x: | | Shape is data dependent |
7474
| | `searchsorted` | :white_check_mark: | 2023.12 | |
7575
| | `where` | :white_check_mark: | | |

cubed/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,12 @@
193193
minimum,
194194
multiply,
195195
negative,
196+
nextafter,
196197
not_equal,
197198
positive,
198199
pow,
199200
real,
201+
reciprocal,
200202
remainder,
201203
round,
202204
sign,
@@ -260,11 +262,13 @@
260262
"maximum",
261263
"minimum",
262264
"multiply",
265+
"nextafter",
263266
"negative",
264267
"not_equal",
265268
"positive",
266269
"pow",
267270
"real",
271+
"reciprocal",
268272
"remainder",
269273
"round",
270274
"sign",
@@ -326,9 +330,15 @@
326330
"unstack",
327331
]
328332

329-
from .array_api.searching_functions import argmax, argmin, searchsorted, where
333+
from .array_api.searching_functions import (
334+
argmax,
335+
argmin,
336+
count_nonzero,
337+
searchsorted,
338+
where,
339+
)
330340

331-
__all__ += ["argmax", "argmin", "searchsorted", "where"]
341+
__all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
332342

333343
from .array_api.statistical_functions import (
334344
cumulative_sum,

cubed/array_api/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,12 @@
136136
minimum,
137137
multiply,
138138
negative,
139+
nextafter,
139140
not_equal,
140141
positive,
141142
pow,
142143
real,
144+
reciprocal,
143145
remainder,
144146
round,
145147
sign,
@@ -203,11 +205,13 @@
203205
"maximum",
204206
"minimum",
205207
"multiply",
208+
"nextafter",
206209
"negative",
207210
"not_equal",
208211
"positive",
209212
"pow",
210213
"real",
214+
"reciprocal",
211215
"remainder",
212216
"round",
213217
"sign",
@@ -264,9 +268,9 @@
264268
"unstack",
265269
]
266270

267-
from .searching_functions import argmax, argmin, searchsorted, where
271+
from .searching_functions import argmax, argmin, count_nonzero, searchsorted, where
268272

269-
__all__ += ["argmax", "argmin", "searchsorted", "where"]
273+
__all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
270274

271275
from .statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
272276

cubed/array_api/elementwise_functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,13 @@ def multiply(x1, x2, /):
372372
return elemwise(nxp.multiply, x1, x2, dtype=result_type(x1, x2))
373373

374374

375+
def nextafter(x1, x2, /):
376+
x1, x2 = _promote_scalars(x1, x2, "nextafter")
377+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
378+
raise TypeError("Only real floating-point dtypes are allowed in nextafter")
379+
return elemwise(nxp.nextafter, x1, x2, dtype=x1.dtype)
380+
381+
375382
def negative(x, /):
376383
if x.dtype not in _numeric_dtypes:
377384
raise TypeError("Only numeric dtypes are allowed in negative")
@@ -406,6 +413,12 @@ def real(x, /):
406413
return elemwise(nxp.real, x, dtype=dtype)
407414

408415

416+
def reciprocal(x, /):
417+
if x.dtype not in _floating_dtypes:
418+
raise TypeError("Only floating-point dtypes are allowed in reciprocal")
419+
return elemwise(nxp.reciprocal, x, dtype=x.dtype)
420+
421+
409422
def remainder(x1, x2, /):
410423
x1, x2 = _promote_scalars(x1, x2, "remainder")
411424
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:

cubed/array_api/inspection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33

44
class __array_namespace_info__:
5-
# capabilities are determined by Cubed, not the backend array API
65
def capabilities(self):
76
return {
8-
"boolean indexing": False,
9-
"data-dependent shapes": False,
7+
"boolean indexing": False, # not supported in Cubed (#73)
8+
"data-dependent shapes": False, # not supported in Cubed
9+
"max dimensions": nxp.__array_namespace_info__().capabilities()[
10+
"max dimensions"
11+
],
1012
}
1113

1214
# devices and dtypes are determined by the backend array API

cubed/array_api/searching_functions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from cubed.array_api.creation_functions import asarray, zeros_like
2-
from cubed.array_api.data_type_functions import result_type
2+
from cubed.array_api.data_type_functions import astype, result_type
33
from cubed.array_api.dtypes import _promote_scalars, _real_numeric_dtypes
44
from cubed.array_api.manipulation_functions import reshape
5-
from cubed.array_api.statistical_functions import max
5+
from cubed.array_api.statistical_functions import max, sum
66
from cubed.backend_array_api import namespace as nxp
77
from cubed.core.ops import arg_reduction, blockwise, elemwise
88

@@ -39,6 +39,17 @@ def argmin(x, /, *, axis=None, keepdims=False, split_every=None):
3939
)
4040

4141

42+
def count_nonzero(x, /, *, axis=None, keepdims=False, split_every=None):
43+
dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)["indexing"]
44+
return sum(
45+
astype(x, nxp.bool),
46+
axis=axis,
47+
dtype=dtype,
48+
keepdims=keepdims,
49+
split_every=split_every,
50+
)
51+
52+
4253
def searchsorted(x1, x2, /, *, side="left", sorter=None):
4354
if x1.ndim != 1:
4455
raise ValueError("Input array x1 must be one dimensional")

0 commit comments

Comments
 (0)