Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@
from keras.src.ops.numpy import triu as triu
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unique as unique
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
from keras.src.ops.numpy import triu as triu
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unique as unique
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@
from keras.src.ops.numpy import triu as triu
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unique as unique
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
from keras.src.ops.numpy import triu as triu
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unique as unique
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
Expand Down
21 changes: 21 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,3 +1637,24 @@ def argpartition(x, kth, axis=-1):

def histogram(x, bins=10, range=None):
return jnp.histogram(x, bins=bins, range=range)


def unique(
x,
sorted=True,
return_inverse=False,
return_counts=False,
axis=None,
size=None,
fill_value=None,
):
return jnp.unique(
x,
return_inverse=return_inverse,
return_counts=return_counts,
axis=axis,
equal_nan=False,
size=size,
sorted=sorted,
fill_value=fill_value,
)
52 changes: 52 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,3 +1680,55 @@ def argpartition(x, kth, axis=-1):

def histogram(x, bins=10, range=None):
return np.histogram(x, bins=bins, range=range)


def unique(
x,
sorted=True,
return_inverse=False,
return_counts=False,
axis=None,
size=None,
fill_value=None,
):
# Note: np.unique always sorts the output in versions < 2.3.0.
# We accept the 'sorted' argument for API consistency across backends
# but do not pass it to np.unique to avoid TypeError in older versions.
output = np.unique(
x,
return_inverse=return_inverse,
return_counts=return_counts,
axis=axis,
equal_nan=False,
)

if not (return_inverse or return_counts):
output = [output]
else:
output = list(output)

values = output[0]

if size is not None:
dim = axis if axis is not None else 0
values_count = values.shape[dim]

if values_count > size:
# Truncate
indices = [slice(None)] * values.ndim
indices[dim] = slice(0, size)
values = values[tuple(indices)]
if return_counts:
output[-1] = output[-1][tuple(indices)]

elif values_count < size:
# Pad
pad_width = [(0, 0)] * values.ndim
pad_width[dim] = (0, size - values_count)
fill = 0 if fill_value is None else fill_value
values = np.pad(values, pad_width, constant_values=fill)
if return_counts:
output[-1] = np.pad(output[-1], pad_width, constant_values=0)

output[0] = values
return output[0] if len(output) == 1 else tuple(output)
1 change: 1 addition & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ NumpyOneInputOpsCorrectnessTest::test_imag
NumpyOneInputOpsCorrectnessTest::test_isreal
NumpyOneInputOpsCorrectnessTest::test_nanmedian
NumpyOneInputOpsCorrectnessTest::test_real
NumpyOneInputOpsCorrectnessTest::test_unique
QuantizersTest::test_compute_float8_scale
QuantizersTest::test_grouped_quantize_with_padding
RandAugmentTest::test_layer
Expand Down
14 changes: 14 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5573,3 +5573,17 @@ def histogram(x, bins=10, range=None):
)

return OpenVINOKerasTensor(hist.output(0)), OpenVINOKerasTensor(bin_edges)


def unique(
x,
sorted=True,
return_inverse=False,
return_counts=False,
axis=None,
size=None,
fill_value=None,
):
raise NotImplementedError(
"OpenVINO backend does not support the `unique` operation."
)
107 changes: 107 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3761,3 +3761,110 @@ def histogram(x, bins=10, range=None):
shape=(bins,),
)
return bin_counts, bin_edges


def unique(
x,
sorted=True,
return_inverse=False,
return_counts=False,
axis=None,
size=None,
fill_value=None,
):
x = tf.convert_to_tensor(x)
is_flatten = axis is None
original_shape = tf.shape(x)

if is_flatten:
x = tf.reshape(x, [-1])
dim = 0
if return_counts:
y, inverse, counts = tf.unique_with_counts(x)
else:
y, inverse = tf.unique(x)
counts = None
else:
ndim = x.shape.rank
dim = axis + ndim if axis < 0 else axis
axis_to_use = tf.constant([dim], dtype=tf.int32)
y, inverse, counts = tf.raw_ops.UniqueWithCountsV2(
x=x, axis=axis_to_use, out_idx=tf.int32
)
if not return_counts:
counts = None

if sorted:
num_unique = tf.shape(y)[dim]
if is_flatten or y.shape.rank == 1:
sort_order = tf.argsort(y)
else:
# Multi-D lexicographical sort
perm = list(range(y.shape.rank))
perm[0], perm[dim] = perm[dim], perm[0]
y_transposed = tf.transpose(y, perm)
y_2d = tf.reshape(y_transposed, [num_unique, -1])
num_cols = tf.shape(y_2d)[1]

sort_order = tf.range(num_unique, dtype=tf.int32)

def body(i, current_indices):
col = tf.gather(y_2d[:, i], current_indices)
perm_sort = tf.argsort(col, stable=True)
return i - 1, tf.gather(current_indices, perm_sort)

def cond(i, current_indices):
return i >= 0

_, sort_order = tf.while_loop(
cond, body, [num_cols - 1, sort_order], parallel_iterations=1
)

y = tf.gather(y, sort_order, axis=dim)
if return_counts:
counts = tf.gather(counts, sort_order)
if return_inverse:
# Must invert permutation to map inverse indices correctly
inv_perm = tf.math.invert_permutation(sort_order)
inverse = tf.gather(inv_perm, inverse)

# Static size padding/truncation (branchless logic for graph mode safety)
if size is not None:
values_count = tf.shape(y)[dim]

# 1. Truncate using gather
truncate_size = tf.minimum(values_count, size)
y = tf.gather(y, tf.range(truncate_size), axis=dim)
if return_counts:
counts = tf.gather(counts, tf.range(truncate_size))

# 2. Pad using tf.pad (pad_amount = 0 makes it a no-op)
pad_amount = tf.maximum(0, size - values_count)
paddings = tf.zeros([tf.rank(y), 2], dtype=tf.int32)
paddings = tf.tensor_scatter_nd_update(
paddings, [[dim, 1]], [pad_amount]
)

fill = tf.cast(0 if fill_value is None else fill_value, y.dtype)
y = tf.pad(y, paddings, constant_values=fill)

if return_counts:
counts = tf.pad(counts, [[0, pad_amount]], constant_values=0)

# 3. Enforce static shape for JAX/XLA compatibility
static_shape = y.shape.as_list()
static_shape[dim] = size
y.set_shape(static_shape)
if return_counts:
counts.set_shape([size])

if return_inverse and is_flatten:
inverse = tf.reshape(inverse, original_shape)

results = [y]
if return_inverse:
results.append(inverse)
if return_counts:
results.append(counts)

return tuple(results) if len(results) > 1 else results[0]
57 changes: 57 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,3 +2328,60 @@ def set_to_zero(a, i):
def histogram(x, bins=10, range=None):
hist_result = torch.histogram(x, bins=bins, range=range)
return hist_result.hist, hist_result.bin_edges


def unique(
x,
sorted=True,
return_inverse=False,
return_counts=False,
axis=None,
size=None,
fill_value=None,
):
if not isinstance(x, torch.Tensor):
x = torch.as_tensor(x)

output = torch.unique(
x,
sorted=sorted, # Added sorted parameter here
return_inverse=return_inverse,
return_counts=return_counts,
dim=axis,
)

if not (return_inverse or return_counts):
output = [output]
else:
output = list(output)

values = output[0]

if size is not None:
dim = axis if axis is not None else 0
values_count = values.shape[dim]

if values_count > size:
# Truncate
indices = [slice(None)] * values.ndim
indices[dim] = slice(0, size)
values = values[tuple(indices)]
if return_counts:
output[-1] = output[-1][tuple(indices)]

elif values_count < size:
# Pad
diff = size - values_count
pad_width = [0, 0] * values.ndim
# F.pad expects padding from last dim to first
idx = (values.ndim - 1 - dim) * 2
pad_width[idx + 1] = diff
fill = 0 if fill_value is None else fill_value
values = torch.nn.functional.pad(values, pad_width, value=fill)
if return_counts:
output[-1] = torch.nn.functional.pad(
output[-1], pad_width, value=0
)

output[0] = values
return output[0] if len(output) == 1 else tuple(output)
Loading