Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
return ret


def median(x: TensorLike, axis=None) -> TensorVariable:
    """
    Computes the median along the given axis(es) of a tensor `x`.

    Parameters
    ----------
    x: TensorVariable
        The input tensor.
    axis: None or int or (list of int) (see `Sum`)
        Compute the median along this axis of the tensor.
        None means all axes (like numpy).

    Returns
    -------
    TensorVariable
        The median of `x` over `axis`; the reduced axes are removed from the
        output. When the number of reduced elements is even, the result is
        the average of the two middle values of the sorted data. The output
        dtype is the one produced by that averaging (true division by 2.0),
        for both the odd and the even case.
    """
    # Function-scope import kept from the original implementation.
    from pytensor.ifelse import ifelse

    x = as_tensor_variable(x)
    x_ndim = x.type.ndim
    if axis is None:
        axis = list(range(x_ndim))
    else:
        axis = list(normalize_axis_tuple(axis, x_ndim))

    non_axis = [i for i in range(x_ndim) if i not in axis]
    non_axis_shape = [x.shape[i] for i in non_axis]

    # Put the reduced axes at the end and ravel them into a single last axis.
    x_raveled = x.transpose(*non_axis, *axis)
    # With exactly one reduced axis the transpose already leaves it last, so
    # the reshape is skipped as a small optimization. With several reduced
    # axes they are merged into one. With *no* reduced axis (0-d input, or
    # `axis=()`) the reshape appends a trailing axis of length 1, so the code
    # below reduces over that axis instead of failing on `shape[-1]` or
    # reducing over an axis the caller asked to keep.
    if len(axis) != 1:
        x_raveled = x_raveled.reshape((*non_axis_shape, -1))

    raveled_size = x_raveled.shape[-1]
    k = raveled_size // 2

    # Sort along the raveled axis and pick the middle value(s). Plain
    # indexing on the last axis is enough here; `take_along_axis` is not
    # needed because the same position `k` is used for every row.
    x_sorted = x_raveled.sort(axis=-1)
    k_values = x_sorted[..., k]
    # For a length-1 axis `k - 1` is -1, which selects the same single
    # element as `k`.
    km1_values = x_sorted[..., k - 1]

    even_median = (k_values + km1_values) / 2.0
    # Cast the odd-length result so both `ifelse` branches share one dtype.
    odd_median = k_values.astype(even_median.type.dtype)
    even_k = eq(mod(raveled_size, 2), 0)
    return ifelse(even_k, even_median, odd_median, name="median")


@scalar_elemwise(symbolname="scalar_maximum")
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
Expand Down Expand Up @@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"sum",
"prod",
"mean",
"median",
"var",
"std",
"std",
Expand Down
31 changes: 31 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
max_and_argmax,
maximum,
mean,
median,
min,
minimum,
mod,
Expand Down Expand Up @@ -3735,3 +3736,33 @@ def test_nan_to_num(nan, posinf, neginf):
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)


@pytest.mark.parametrize(
    "ndim, axis",
    [
        (2, None),
        (2, 1),
        (2, (0, 1)),
        (3, None),
        (3, (1, 2)),
        (4, (1, 3, 0)),
    ],
)
def test_median(ndim, axis):
    """Check `median` against `np.median` for odd and even reduced sizes."""
    # Seeded generator so that any failure can be reproduced exactly.
    rng = np.random.default_rng(seed=3735)

    # Every dimension is even in `shape_even` and odd in `shape_odd`, so the
    # number of reduced elements is even / odd for any choice of `axis`.
    shape_even = np.arange(1, ndim + 1) * 2
    shape_odd = shape_even - 1

    data_even = rng.random(size=tuple(shape_even))
    data_odd = rng.random(size=tuple(shape_odd))

    # One compiled function with unknown static shape serves both inputs.
    x = tensor(dtype="float64", shape=(None,) * ndim)
    f = function([x], median(x, axis=axis))
    result_odd = f(data_odd)
    result_even = f(data_even)
    expected_odd = np.median(data_odd, axis=axis)
    expected_even = np.median(data_even, axis=axis)

    # `assert_allclose` reports the mismatching values on failure.
    np.testing.assert_allclose(result_odd, expected_odd)
    np.testing.assert_allclose(result_even, expected_even)
Loading