Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@
from keras.src.ops.numpy import nancumsum as nancumsum
from keras.src.ops.numpy import nanmax as nanmax
from keras.src.ops.numpy import nanmean as nanmean
from keras.src.ops.numpy import nanmedian as nanmedian
from keras.src.ops.numpy import nanmin as nanmin
from keras.src.ops.numpy import nanprod as nanprod
from keras.src.ops.numpy import nanquantile as nanquantile
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
from keras.src.ops.numpy import nancumsum as nancumsum
from keras.src.ops.numpy import nanmax as nanmax
from keras.src.ops.numpy import nanmean as nanmean
from keras.src.ops.numpy import nanmedian as nanmedian
from keras.src.ops.numpy import nanmin as nanmin
from keras.src.ops.numpy import nanprod as nanprod
from keras.src.ops.numpy import nanquantile as nanquantile
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@
from keras.src.ops.numpy import nancumsum as nancumsum
from keras.src.ops.numpy import nanmax as nanmax
from keras.src.ops.numpy import nanmean as nanmean
from keras.src.ops.numpy import nanmedian as nanmedian
from keras.src.ops.numpy import nanmin as nanmin
from keras.src.ops.numpy import nanprod as nanprod
from keras.src.ops.numpy import nanquantile as nanquantile
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
from keras.src.ops.numpy import nancumsum as nancumsum
from keras.src.ops.numpy import nanmax as nanmax
from keras.src.ops.numpy import nanmean as nanmean
from keras.src.ops.numpy import nanmedian as nanmedian
from keras.src.ops.numpy import nanmin as nanmin
from keras.src.ops.numpy import nanprod as nanprod
from keras.src.ops.numpy import nanquantile as nanquantile
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,11 @@ def nanmean(x, axis=None, keepdims=False):
return jnp.nanmean(x, axis=axis, keepdims=keepdims)


def nanmedian(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
return jnp.nanmedian(x, axis=axis, keepdims=keepdims)


def nanmin(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
return jnp.nanmin(x, axis=axis, keepdims=keepdims)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,11 @@ def nanmean(x, axis=None, keepdims=False):
return np.nanmean(x, axis=axis, keepdims=keepdims).astype(dtype)


def nanmedian(x, axis=None, keepdims=False):
dtype = dtypes.result_type(standardize_dtype(x.dtype), float)
return np.nanmedian(x, axis=axis, keepdims=keepdims).astype(dtype)


def nanmin(x, axis=None, keepdims=False):
return np.nanmin(x, axis=axis, keepdims=keepdims)

Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,11 @@ NNOpsCorrectnessTest::test_sparsemax
NNOpsDtypeTest::test_ctc_decode
NNOpsDtypeTest::test_glu_
NNOpsDtypeTest::test_polar_
NumpyDtypeTest::test_nanmedian
NumpyOneInputOpsCorrectnessTest::test_conj
NumpyOneInputOpsCorrectnessTest::test_imag
NumpyOneInputOpsCorrectnessTest::test_isreal
NumpyOneInputOpsCorrectnessTest::test_nanmedian
NumpyOneInputOpsCorrectnessTest::test_real
QuantizersTest::test_compute_float8_scale
QuantizersTest::test_grouped_quantize_with_padding
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,12 @@ def nanmean(x, axis=None, keepdims=False):
return OpenVINOKerasTensor(result)


def nanmedian(x, axis=None, keepdims=False):
raise NotImplementedError(
"`nanmedian` is not supported with openvino backend"
)


def nanmin(x, axis=None, keepdims=False):
if isinstance(x, np.ndarray) and x.dtype == np.float64:
# conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138
Expand Down
9 changes: 9 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,15 @@ def nanmean(x, axis=None, keepdims=False):
return tf.divide(total_sum, normalizer)


def nanmedian(x, axis=None, keepdims=False):
x = convert_to_tensor(x)

if axis == () or axis == []:
return x

return nanquantile(x, q=0.5, axis=axis, keepdims=keepdims)


def nanmin(x, axis=None, keepdims=False):
x = convert_to_tensor(x)

Expand Down
9 changes: 9 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,15 @@ def nanmean(x, axis=None, keepdims=False):
return torch.nanmean(cast(x, dtype), dim=axis, keepdim=keepdims)


def nanmedian(x, axis=None, keepdims=False):
x = convert_to_tensor(x)

if axis == () or axis == []:
return x

return nanquantile(x, q=0.5, axis=axis, keepdims=keepdims)


def nanmin(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
if not torch.is_floating_point(x):
Expand Down
58 changes: 58 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5834,6 +5834,64 @@ def nanmean(x, axis=None, keepdims=False):
return backend.numpy.nanmean(x, axis=axis, keepdims=keepdims)


class Nanmedian(Operation):
def __init__(self, axis=None, keepdims=False, *, name=None):
super().__init__(name=name)
self.axis = axis
self.keepdims = keepdims

def call(self, x):
return backend.numpy.nanmedian(
x, axis=self.axis, keepdims=self.keepdims
)

def compute_output_spec(self, x):
dtype = dtypes.result_type(x.dtype, float)
return KerasTensor(
reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
dtype=dtype,
)


@keras_export(["keras.ops.nanmedian", "keras.ops.numpy.nanmedian"])
def nanmedian(x, axis=None, keepdims=False):
"""Median of a tensor over the given axes, ignoring NaNs.

This function computes the median along the specified axis or axes,
skipping any NaN values. If all values along a reduced axis are NaN,
the result is NaN.

Args:
x: Input tensor.
axis: Axis or axes along which the median is computed.
If None (default), the median of the flattened tensor is returned.
keepdims: If True, the reduced axes are retained as dimensions
with size one. Defaults to False.

Returns:
Tensor with the median values, ignoring NaNs.

Examples:
>>> import numpy as np
>>> from keras import ops
>>> x = np.array([[1.0, np.nan, 3.0],
... [np.nan, 2.0, 1.0]])
>>> ops.nanmedian(x)
1.5

>>> ops.nanmedian(x, axis=1)
array([2., 1.5])

>>> ops.nanmedian(x, axis=1, keepdims=True)
array([[2. ],
[1.5]])
"""
if any_symbolic_tensors((x,)):
return Nanmedian(axis=axis, keepdims=keepdims).symbolic_call(x)

return backend.numpy.nanmedian(x, axis=axis, keepdims=keepdims)


class Nanmin(Operation):
def __init__(self, axis=None, keepdims=False, *, name=None):
super().__init__(name=name)
Expand Down
91 changes: 91 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,29 @@ def test_nanmean(self):
self.assertEqual(knp.nanmean(x4, axis=2).shape, (None, 2, 4))
self.assertEqual(knp.nanmean(x4, axis=(1, 3)).shape, (None, 3))

def test_nanmedian(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.nanmedian(x).shape, ())

x = KerasTensor((None, 3, 3))
self.assertEqual(knp.nanmedian(x, axis=1).shape, (None, 3))
self.assertEqual(
knp.nanmedian(x, axis=1, keepdims=True).shape, (None, 1, 3)
)

self.assertEqual(knp.nanmedian(x, axis=(1,)).shape, (None, 3))

self.assertEqual(knp.nanmedian(x, axis=(1, 2)).shape, (None,))
self.assertEqual(
knp.nanmedian(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)
)

self.assertEqual(knp.nanmedian(x, axis=()).shape, (None, 3, 3))

x4 = KerasTensor((None, 2, 3, 4))
self.assertEqual(knp.nanmedian(x4, axis=2).shape, (None, 2, 4))
self.assertEqual(knp.nanmedian(x4, axis=(1, 3)).shape, (None, 3))

def test_nanmin(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.nanmin(x).shape, ())
Expand Down Expand Up @@ -2752,6 +2775,13 @@ def test_nanmean(self):
self.assertEqual(knp.nanmean(x, axis=1).shape, (2,))
self.assertEqual(knp.nanmean(x, axis=1, keepdims=True).shape, (2, 1))

def test_nanmedian(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.nanmedian(x).shape, ())
self.assertEqual(knp.nanmedian(x, axis=0).shape, (3,))
self.assertEqual(knp.nanmedian(x, axis=1).shape, (2,))
self.assertEqual(knp.nanmedian(x, axis=1, keepdims=True).shape, (2, 1))

def test_nanmin(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.nanmin(x).shape, ())
Expand Down Expand Up @@ -6551,6 +6581,48 @@ def test_nanmean(self):
np.nanmean(x_3d, axis=(1, 2)),
)

def test_nanmedian(self):
x = np.array(
[[1.0, np.nan, 3.0, 4.0, 5.0], [np.nan, 2.0, 3.0, np.inf, -np.inf]]
)

self.assertAllClose(knp.nanmedian(x), np.nanmedian(x))
self.assertAllClose(knp.nanmedian(x, axis=()), np.nanmedian(x, axis=()))
self.assertAllClose(knp.nanmedian(x, axis=1), np.nanmedian(x, axis=1))
self.assertAllClose(
knp.nanmedian(x, axis=(1,)), np.nanmedian(x, axis=(1,))
)
self.assertAllClose(
knp.nanmedian(x, axis=1, keepdims=True),
np.nanmedian(x, axis=1, keepdims=True),
)

self.assertAllClose(knp.Nanmedian()(x), np.nanmedian(x))
self.assertAllClose(knp.Nanmedian(axis=1)(x), np.nanmedian(x, axis=1))
self.assertAllClose(
knp.Nanmedian(axis=1, keepdims=True)(x),
np.nanmedian(x, axis=1, keepdims=True),
)

x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])
self.assertAllClose(knp.nanmedian(x_all_nan), np.nanmedian(x_all_nan))
self.assertAllClose(
knp.nanmedian(x_all_nan, axis=1),
np.nanmedian(x_all_nan, axis=1),
)

x_3d = np.array(
[
[[1.0, np.nan], [2.0, 3.0]],
[[np.nan, 4.0], [5.0, np.nan]],
]
)
self.assertAllClose(knp.nanmedian(x_3d), np.nanmedian(x_3d))
self.assertAllClose(
knp.nanmedian(x_3d, axis=(1, 2)),
np.nanmedian(x_3d, axis=(1, 2)),
)

def test_nanmin(self):
x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, np.inf]])

Expand Down Expand Up @@ -10056,6 +10128,25 @@ def test_nanmean(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_nanmedian(self, dtype):
import jax.numpy as jnp

x = knp.ones((1,), dtype=dtype)
x_jax = jnp.ones((1,), dtype=dtype)
expected_dtype = standardize_dtype(jnp.nanmedian(x_jax).dtype)

if backend.backend() == "torch" and expected_dtype == "uint32":
expected_dtype = "int32"

self.assertEqual(
standardize_dtype(knp.nanmedian(x).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Nanmedian().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_nanmin(self, dtype):
import jax.numpy as jnp
Expand Down
Loading