Skip to content

Commit aba71e8

Browse files
committed
ENH: add cumulative_prod
1 parent beac55b commit aba71e8

File tree

5 files changed

+34
-0
lines changed

5 files changed

+34
-0
lines changed

array_api_compat/common/_aliases.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,36 @@ def cumulative_sum(
297297
)
298298
return res
299299

300+
301+
def cumulative_prod(
302+
x: ndarray,
303+
/,
304+
xp,
305+
*,
306+
axis: Optional[int] = None,
307+
dtype: Optional[Dtype] = None,
308+
include_initial: bool = False,
309+
**kwargs
310+
) -> ndarray:
311+
wrapped_xp = array_namespace(x)
312+
313+
if axis is None:
314+
if x.ndim > 1:
315+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
316+
axis = 0
317+
318+
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
319+
320+
# np.cumprod does not support include_initial
321+
if include_initial:
322+
initial_shape = list(x.shape)
323+
initial_shape[axis] = 1
324+
res = xp.concatenate(
325+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
326+
axis=axis,
327+
)
328+
return res
329+
300330
# The min and max argument names in clip are different and not optional in numpy, and type
301331
# promotion behavior is different.
302332
def clip(

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
std = get_xp(cp)(_aliases.std)
5151
var = get_xp(cp)(_aliases.var)
5252
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
53+
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
5354
clip = get_xp(cp)(_aliases.clip)
5455
permute_dims = get_xp(cp)(_aliases.permute_dims)
5556
reshape = get_xp(cp)(_aliases.reshape)

array_api_compat/dask/array/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def _dask_arange(
8686
std = get_xp(da)(_aliases.std)
8787
var = get_xp(da)(_aliases.var)
8888
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
89+
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
8990
empty = get_xp(da)(_aliases.empty)
9091
empty_like = get_xp(da)(_aliases.empty_like)
9192
full = get_xp(da)(_aliases.full)

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
std = get_xp(np)(_aliases.std)
5151
var = get_xp(np)(_aliases.var)
5252
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
53+
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
5354
clip = get_xp(np)(_aliases.clip)
5455
permute_dims = get_xp(np)(_aliases.permute_dims)
5556
reshape = get_xp(np)(_aliases.reshape)

array_api_compat/torch/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
204204
clip = get_xp(torch)(_aliases_clip)
205205
unstack = get_xp(torch)(_aliases_unstack)
206206
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
207+
cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
207208

208209
# torch.sort also returns a tuple
209210
# https://github.com/pytorch/pytorch/issues/70921

0 commit comments

Comments
 (0)