Skip to content

Commit ab233c9

Browse files
authored
Add cumulative_prod (#795)
1 parent 4dc6ac9 commit ab233c9

File tree

6 files changed

+83
-5
lines changed

6 files changed

+83
-5
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ jobs:
114114
array_api_tests/test_array_object.py::test_getitem
115115
# test_searchsorted depends on sort which is not implemented
116116
array_api_tests/test_searching_functions.py::test_searchsorted
117-
# cumulative_sum with include_initial=True is not implemented
117+
# cumulative_* functions with include_initial=True are not implemented
118+
array_api_tests/test_statistical_functions.py::test_cumulative_prod
118119
array_api_tests/test_statistical_functions.py::test_cumulative_sum
119120
120121
# not implemented

api_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
7979
| | `unique_values` | :x: | | Shape is data dependent |
8080
| Sorting Functions | `argsort` | :x: | | |
8181
| | `sort` | :x: | | |
82-
| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | |
82+
| Statistical Functions | `cumulative_prod` | :white_check_mark: | 2024.12 | |
8383
| | `cumulative_sum` | :white_check_mark: | 2023.12 | |
8484
| | `max` | :white_check_mark: | | |
8585
| | `mean` | :white_check_mark: | | |

cubed/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@
341341
__all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
342342

343343
from .array_api.statistical_functions import (
344+
cumulative_prod,
344345
cumulative_sum,
345346
max,
346347
mean,
@@ -351,7 +352,17 @@
351352
var,
352353
)
353354

354-
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
355+
__all__ += [
356+
"cumulative_prod",
357+
"cumulative_sum",
358+
"max",
359+
"mean",
360+
"min",
361+
"prod",
362+
"std",
363+
"sum",
364+
"var",
365+
]
355366

356367
from .array_api.utility_functions import all, any, diff
357368

cubed/array_api/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,29 @@
272272

273273
__all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
274274

275-
from .statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
275+
from .statistical_functions import (
276+
cumulative_prod,
277+
cumulative_sum,
278+
max,
279+
mean,
280+
min,
281+
prod,
282+
std,
283+
sum,
284+
var,
285+
)
276286

277-
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
287+
__all__ += [
288+
"cumulative_prod",
289+
"cumulative_sum",
290+
"max",
291+
"mean",
292+
"min",
293+
"prod",
294+
"std",
295+
"sum",
296+
"var",
297+
]
278298

279299
from .utility_functions import all, any, diff
280300

cubed/array_api/statistical_functions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,42 @@
1010
from cubed.core import reduction, scan
1111

1212

13+
def cumulative_prod(x, /, *, axis=None, dtype=None, include_initial=False, device=None):
14+
if include_initial:
15+
raise NotImplementedError("include_initial is not supported in cumulative_prod")
16+
dtype = _upcast_integral_dtypes(
17+
x,
18+
dtype,
19+
allowed_dtypes=(
20+
"numeric",
21+
"boolean",
22+
),
23+
fname="cumulative_prod",
24+
device=device,
25+
)
26+
return scan(
27+
x,
28+
preop=nxp.prod,
29+
func=_cumulative_prod_func,
30+
binop=nxp.multiply,
31+
axis=axis,
32+
dtype=dtype,
33+
)
34+
35+
36+
def _cumulative_prod_func(a, /, *, axis=None, dtype=None, include_initial=False):
37+
out = nxp.cumulative_prod(
38+
a, axis=axis, dtype=dtype, include_initial=include_initial
39+
)
40+
if include_initial:
41+
# we don't yet support including the final element as it complicates chunk sizing
42+
ind = tuple(
43+
slice(a.shape[i]) if i == axis else slice(None) for i in range(a.ndim)
44+
)
45+
out = out[ind]
46+
return out
47+
48+
1349
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device=None):
1450
if include_initial:
1551
raise NotImplementedError("include_initial is not supported in cumulative_sum")

cubed/tests/test_array_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,16 @@ def test_where_scalars():
850850
# Statistical functions
851851

852852

853+
@pytest.mark.parametrize("axis", [0, 1])
854+
def test_cumulative_prod_2d(axis):
855+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
856+
b = xp.cumulative_prod(a, axis=axis)
857+
assert_array_equal(
858+
b.compute(),
859+
np.cumulative_prod(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), axis=axis),
860+
)
861+
862+
853863
@pytest.mark.parametrize("axis", [0, 1])
854864
def test_cumulative_sum_2d(axis):
855865
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))

0 commit comments

Comments
 (0)