@@ -407,6 +407,43 @@ def _concat_diff_input(arr, axis, prepend, append):
407
407
408
408
409
409
def diff (x , / , * , axis = - 1 , n = 1 , prepend = None , append = None ):
410
+ """
411
+ Calculates the `n`-th discrete forward difference of `x` along `axis`.
412
+
413
+ Args:
414
+ x (usm_ndarray):
415
+ input array.
416
+ axis (int):
417
+ axis along which to compute the difference. A valid axis must be on
418
+ the interval `[-N, N)`, where `N` is the rank (number of
419
+ dimensions) of `x`.
420
+ Default: `-1`
421
+ n (int):
422
+ number of times to recursively compute the difference.
423
+ Default: `1`.
424
+ prepend (Union[usm_ndarray, bool, int, float, complex]):
425
+ value or values to prepend to the specified axis before taking the
426
+ difference.
427
+ Must have the same shape as `x` except along `axis`, which can have
428
+ any shape.
429
+ Default: `None`.
430
+ append (Union[usm_ndarray, bool, int, float, complex]):
431
+ value or values to append to the specified axis before taking the
432
+ difference.
433
+ Must have the same shape as `x` except along `axis`, which can have
434
+ any shape.
435
+ Default: `None`.
436
+
437
+ Returns:
438
+ usm_ndarray:
439
+ an array containing the `n`-th differences. The array will have the
440
+ same shape as `x`, except along `axis`, which will have shape
441
+
442
+ - prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n
443
+
444
+ The data type of the returned array is determined by the Type
445
+ Promotion Rules.
446
+ """
410
447
411
448
if not isinstance (x , dpt .usm_ndarray ):
412
449
raise TypeError (
@@ -417,7 +454,8 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
417
454
n = operator .index (n )
418
455
419
456
arr = _concat_diff_input (x , axis , prepend , append )
420
-
457
+ if n == 0 :
458
+ return arr
421
459
# form slices and recurse
422
460
sl0 = tuple (
423
461
slice (None ) if i != axis else slice (1 , None ) for i in range (x_nd )
0 commit comments