Skip to content

Commit f9bf5ae

Browse files
committed
Docstrings for diff and count_nonzero
1 parent 5524b7f commit f9bf5ae

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

dpctl/tensor/_reduction.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,37 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
775775

776776

777777
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
778+
"""
779+
Counts the number of elements in the input array ``x`` which are non-zero.
780+
781+
Args:
782+
x (usm_ndarray):
783+
input array.
784+
axis (Optional[int, Tuple[int, ...]]):
785+
axis or axes along which to count. If a tuple of unique integers,
786+
the number of non-zero values are computed over multiple axes.
787+
If ``None``, the number of non-zero values is computed over the
788+
entire array.
789+
Default: ``None``.
790+
keepdims (Optional[bool]):
791+
if ``True``, the reduced axes (dimensions) are included in the
792+
result as singleton dimensions, so that the returned array remains
793+
compatible with the input arrays according to Array Broadcasting
794+
rules. Otherwise, if ``False``, the reduced axes are not included
795+
in the returned array. Default: ``False``.
796+
out (Optional[usm_ndarray]):
797+
the array into which the result is written.
798+
The data type of ``out`` must match the expected shape and data
799+
type.
800+
If ``None`` then a new array is returned. Default: ``None``.
801+
802+
Returns:
803+
usm_ndarray:
804+
an array containing the count of non-zero values. If the sum was
805+
computed over the entire array, a zero-dimensional array is
806+
returned. The returned array will have the default array index data
807+
type.
808+
"""
778809
if x.dtype != dpt.bool:
779810
x = dpt.astype(x, dpt.bool, copy=False)
780811
return sum(

dpctl/tensor/_utility_functions.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,43 @@ def _concat_diff_input(arr, axis, prepend, append):
409409

410410

411411
def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
412+
"""
413+
Calculates the `n`-th discrete forward difference of `x` along `axis`.
414+
415+
Args:
416+
x (usm_ndarray):
417+
input array.
418+
axis (int):
419+
axis along which to compute the difference. A valid axis must be on
420+
the interval `[-N, N)`, where `N` is the rank (number of
421+
dimensions) of `x`.
422+
Default: `-1`
423+
n (int):
424+
number of times to recursively compute the difference.
425+
Default: `1`.
426+
prepend (Union[usm_ndarray, bool, int, float, complex]):
427+
value or values to prepend to the specified axis before taking the
428+
difference.
429+
Must have the same shape as `x` except along `axis`, which can have
430+
any shape.
431+
Default: `None`.
432+
append (Union[usm_ndarray, bool, int, float, complex]):
433+
value or values to append to the specified axis before taking the
434+
difference.
435+
Must have the same shape as `x` except along `axis`, which can have
436+
any shape.
437+
Default: `None`.
438+
439+
Returns:
440+
usm_ndarray:
441+
an array containing the `n`-th differences. The array will have the
442+
same shape as `x`, except along `axis`, which will have shape
443+
444+
- prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n
445+
446+
The data type of the returned array is determined by the Type
447+
Promotion Rules.
448+
"""
412449

413450
if not isinstance(x, dpt.usm_ndarray):
414451
raise TypeError(
@@ -419,7 +456,8 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
419456
n = operator.index(n)
420457

421458
arr = _concat_diff_input(x, axis, prepend, append)
422-
459+
if n == 0:
460+
return arr
423461
# form slices and recurse
424462
sl0 = tuple(
425463
slice(None) if i != axis else slice(1, None) for i in range(x_nd)

0 commit comments

Comments
 (0)