Skip to content

Commit 9c2ce26

Browse files
committed
Implement diff for XTensorVariables
1 parent a5f0aab commit 9c2ce26

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pytensor/xtensor/type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,15 @@ def cumsum(self, dim):
535535
def cumprod(self, dim):
536536
return px.reduction.cumprod(self, dim)
537537

538+
def diff(self, dim, n=1):
539+
"""Compute the n-th discrete difference along the given dimension."""
540+
slice1 = {dim: slice(1, None)}
541+
slice2 = {dim: slice(None, -1)}
542+
x = self
543+
for _ in range(n):
544+
x = x[slice1] - x[slice2]
545+
return x
546+
538547
# Reshaping and reorganizing
539548
# https://docs.xarray.dev/en/latest/api.html#id8
540549
def transpose(

tests/xtensor/test_indexing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,22 @@ def test_indexing_renames_into_update_variable():
486486
expected_result = x_test.copy()
487487
expected_result[idx_test] = y_test
488488
xr_assert_allclose(result, expected_result)
489+
490+
491+
@pytest.mark.parametrize("n", ["implicit", 1, 2])
492+
@pytest.mark.parametrize("dim", ["a", "b"])
493+
def test_diff(dim, n):
494+
x = xtensor(dims=("a", "b"), shape=(7, 11))
495+
if n == "implicit":
496+
out = x.diff(dim)
497+
else:
498+
out = x.diff(dim, n=n)
499+
500+
fn = xr_function([x], out)
501+
x_test = xr_arange_like(x)
502+
res = fn(x_test)
503+
if n == "implicit":
504+
expected_res = x_test.diff(dim)
505+
else:
506+
expected_res = x_test.diff(dim, n=n)
507+
xr_assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)