Skip to content

Commit 3315646

Browse files
committed
Adds docstrings for diff and count_nonzero
1 parent 8c0f905 commit 3315646

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

dpctl/tensor/_reduction.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,37 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
776776

777777

778778
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
779+
"""
780+
Counts the number of elements in the input array ``x`` which are non-zero.
781+
782+
Args:
783+
x (usm_ndarray):
784+
input array.
785+
axis (Optional[int, Tuple[int, ...]]):
786+
axis or axes along which to count. If a tuple of unique integers,
787+
the number of non-zero values are computed over multiple axes.
788+
If ``None``, the number of non-zero values is computed over the
789+
entire array.
790+
Default: ``None``.
791+
keepdims (Optional[bool]):
792+
if ``True``, the reduced axes (dimensions) are included in the
793+
result as singleton dimensions, so that the returned array remains
794+
compatible with the input arrays according to Array Broadcasting
795+
rules. Otherwise, if ``False``, the reduced axes are not included
796+
in the returned array. Default: ``False``.
797+
out (Optional[usm_ndarray]):
798+
the array into which the result is written.
799+
The data type of ``out`` must match the expected shape and data
800+
type.
801+
If ``None`` then a new array is returned. Default: ``None``.
802+
803+
Returns:
804+
usm_ndarray:
805+
an array containing the count of non-zero values. If the sum was
806+
computed over the entire array, a zero-dimensional array is
807+
returned. The returned array will have the default array index data
808+
type.
809+
"""
779810
if x.dtype != dpt.bool:
780811
x = dpt.astype(x, dpt.bool, copy=False)
781812
return sum(

dpctl/tensor/_utility_functions.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,43 @@ def _concat_diff_input(arr, axis, prepend, append):
407407

408408

409409
def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
410+
"""
411+
Calculates the `n`-th discrete forward difference of `x` along `axis`.
412+
413+
Args:
414+
x (usm_ndarray):
415+
input array.
416+
axis (int):
417+
axis along which to compute the difference. A valid axis must be on
418+
the interval `[-N, N)`, where `N` is the rank (number of
419+
dimensions) of `x`.
420+
Default: `-1`
421+
n (int):
422+
number of times to recursively compute the difference.
423+
Default: `1`.
424+
prepend (Union[usm_ndarray, bool, int, float, complex]):
425+
value or values to prepend to the specified axis before taking the
426+
difference.
427+
Must have the same shape as `x` except along `axis`, which can have
428+
any shape.
429+
Default: `None`.
430+
append (Union[usm_ndarray, bool, int, float, complex]):
431+
value or values to append to the specified axis before taking the
432+
difference.
433+
Must have the same shape as `x` except along `axis`, which can have
434+
any shape.
435+
Default: `None`.
436+
437+
Returns:
438+
usm_ndarray:
439+
an array containing the `n`-th differences. The array will have the
440+
same shape as `x`, except along `axis`, which will have shape
441+
442+
- prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n
443+
444+
The data type of the returned array is determined by the Type
445+
Promotion Rules.
446+
"""
410447

411448
if not isinstance(x, dpt.usm_ndarray):
412449
raise TypeError(
@@ -417,7 +454,8 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
417454
n = operator.index(n)
418455

419456
arr = _concat_diff_input(x, axis, prepend, append)
420-
457+
if n == 0:
458+
return arr
421459
# form slices and recurse
422460
sl0 = tuple(
423461
slice(None) if i != axis else slice(1, None) for i in range(x_nd)

0 commit comments

Comments
 (0)