@@ -409,6 +409,43 @@ def _concat_diff_input(arr, axis, prepend, append):
409
409
410
410
411
411
def diff (x , / , * , axis = - 1 , n = 1 , prepend = None , append = None ):
412
+ """
413
+ Calculates the `n`-th discrete forward difference of `x` along `axis`.
414
+
415
+ Args:
416
+ x (usm_ndarray):
417
+ input array.
418
+ axis (int):
419
+ axis along which to compute the difference. A valid axis must be on
420
+ the interval `[-N, N)`, where `N` is the rank (number of
421
+ dimensions) of `x`.
422
+ Default: `-1`
423
+ n (int):
424
+ number of times to recursively compute the difference.
425
+ Default: `1`.
426
+ prepend (Union[usm_ndarray, bool, int, float, complex]):
427
+ value or values to prepend to the specified axis before taking the
428
+ difference.
429
+ Must have the same shape as `x` except along `axis`, which can have
430
+ any shape.
431
+ Default: `None`.
432
+ append (Union[usm_ndarray, bool, int, float, complex]):
433
+ value or values to append to the specified axis before taking the
434
+ difference.
435
+ Must have the same shape as `x` except along `axis`, which can have
436
+ any shape.
437
+ Default: `None`.
438
+
439
+ Returns:
440
+ usm_ndarray:
441
+ an array containing the `n`-th differences. The array will have the
442
+ same shape as `x`, except along `axis`, which will have shape
443
+
444
+ - prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n
445
+
446
+ The data type of the returned array is determined by the Type
447
+ Promotion Rules.
448
+ """
412
449
413
450
if not isinstance (x , dpt .usm_ndarray ):
414
451
raise TypeError (
@@ -419,7 +456,8 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
419
456
n = operator .index (n )
420
457
421
458
arr = _concat_diff_input (x , axis , prepend , append )
422
-
459
+ if n == 0 :
460
+ return arr
423
461
# form slices and recurse
424
462
sl0 = tuple (
425
463
slice (None ) if i != axis else slice (1 , None ) for i in range (x_nd )
0 commit comments