Skip to content

Commit 8cc14d8

Browse files
committed
Rewrite diff to juggle between two larger temporary allocations when n > 2
Saves some temporary memory allocations
1 parent 2fef994 commit 8cc14d8

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

dpctl/tensor/_utility_functions.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,20 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
465465
)
466466

467467
diff_op = dpt.not_equal if x.dtype == dpt.bool else dpt.subtract
468-
for _ in range(n):
468+
if n > 1:
469+
arr_tmp0 = diff_op(arr[sl0], arr[sl1])
470+
arr_tmp1 = diff_op(arr_tmp0[sl0], arr_tmp0[sl1])
471+
n = n - 2
472+
if n > 0:
473+
sl3 = tuple(
474+
slice(None) if i != axis else slice(None, -2)
475+
for i in range(x_nd)
476+
)
477+
for _ in range(n):
478+
arr_tmp0_sliced = arr_tmp0[sl3]
479+
diff_op(arr_tmp1[sl0], arr_tmp1[sl1], out=arr_tmp0_sliced)
480+
arr_tmp0, arr_tmp1 = arr_tmp1, arr_tmp0_sliced
481+
arr = arr_tmp1
482+
else:
469483
arr = diff_op(arr[sl0], arr[sl1])
470-
471484
return arr

0 commit comments

Comments
 (0)