Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 38 additions & 59 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -1194,23 +1194,22 @@ def __init__(
self.return_index = return_index
self.return_inverse = return_inverse
self.return_counts = return_counts
if axis is not None and axis < 0:
raise ValueError("Axis cannot be negative.")
Comment on lines +1197 to +1198
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this possibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it simplifies the logic in the Op. The user-facing helper, `pt.unique`, handles the negative axis and passes a positive one to the Op. Users don't really create Ops themselves.

self.axis = axis

def make_node(self, x):
x = ptb.as_tensor_variable(x)
self_axis = self.axis
if self_axis is None:
axis = self.axis
if axis is None:
out_shape = (None,)
else:
if self_axis < 0:
self_axis += x.type.ndim
if self_axis < 0 or self_axis >= x.type.ndim:
if axis >= x.type.ndim:
raise ValueError(
f"Unique axis {self.axis} is outside of input ndim = {x.type.ndim}"
f"Axis {axis} out of range for input {x} with ndim={x.type.ndim}."
)
out_shape = tuple(
s if s == 1 and axis != self_axis else None
for axis, s in enumerate(x.type.shape)
None if dim == axis else s for dim, s in enumerate(x.type.shape)
)

outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
Expand All @@ -1224,60 +1223,37 @@ def make_node(self, x):
return Apply(self, [x], outputs)

def perform(self, node, inputs, output_storage):
x = inputs[0]
z = output_storage
param = {}
if self.return_index:
param["return_index"] = True
if self.return_inverse:
param["return_inverse"] = True
if self.return_counts:
param["return_counts"] = True
if self.axis is not None:
param["axis"] = self.axis
outs = np.unique(x, **param)
if (
(not self.return_inverse)
and (not self.return_index)
and (not self.return_counts)
):
z[0][0] = outs
else:
[x] = inputs
outs = np.unique(
x,
return_index=self.return_index,
return_inverse=self.return_inverse,
return_counts=self.return_counts,
axis=self.axis,
)
if isinstance(outs, tuple):
for i in range(len(outs)):
z[i][0] = outs[i]
output_storage[i][0] = outs[i]
else:
output_storage[0][0] = outs

def infer_shape(self, fgraph, node, i0_shapes):
ret = fgraph.shape_feature.default_infer_shape(fgraph, node, i0_shapes)
if self.axis is not None:
self_axis = self.axis
ndim = len(i0_shapes[0])
if self_axis < 0:
self_axis += ndim
if self_axis < 0 or self_axis >= ndim:
raise RuntimeError(
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
)
ret[0] = tuple(
fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
)
[x_shape] = i0_shapes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand what is happening in this function, but just to check, shouldn't there be a case for return_index and return_counts as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is. `i0_shapes` holds the input dimensions, so it doesn't change with the number of outputs. `return_index`/`return_counts` are outputs, and they are always vectors.

We set `out_shapes = [out.shape[0] for out in node.outputs]` by default, which always works for `return_index` and `return_counts`. Then we have special logic for the main output when `axis` is not None, and for `return_inverse`, whose shape is not just `out.shape[0]`.

Copy link
Member Author

@ricardoV94 ricardoV94 Oct 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The big picture is the function tries to return a graph for the shape of the outputs given the input shapes (and possibly values, which you could retrieve from node.inputs). The default graph of the shape is just output.shape, which we try to avoid when possible, as we would like to avoid computing the Op in order to find out its shape.

For unique we can do that for some of the outputs dimensions, but not all (we only know how many repeated values there are if we evaluate Unique).

This method is combining dims we can know from the input shapes and those that we can only get after we compute the outputs with out.shape[0] or out.shape[x].

shape0_op = Shape_i(0)
out_shapes = [(shape0_op(out),) for out in node.outputs]

axis = self.axis
if axis is not None:
shape = list(x_shape)
shape[axis] = Shape_i(axis)(node.outputs[0])
out_shapes[0] = tuple(shape)

if self.return_inverse:
if self.axis is None:
shape = (prod(i0_shapes[0]),)
else:
shape = (i0_shapes[0][self_axis],)
if self.return_index:
ret[2] = shape
return ret
ret[1] = shape
return ret
return ret

def __setstate__(self, state):
self.__dict__.update(state)
# For backwards compatibility with pickled instances of Unique that
# did not have the axis parameter specified
if "axis" not in state:
self.axis = None
shape = prod(x_shape) if self.axis is None else x_shape[axis]
return_index_out_idx = 2 if self.return_index else 1
out_shapes[return_index_out_idx] = (shape,)

return out_shapes


def unique(
Expand All @@ -1293,6 +1269,9 @@ def unique(
* the number of times each unique value comes up in the input array

"""
ar = as_tensor_variable(ar)
if axis is not None:
axis = normalize_axis_index(axis, ar.ndim)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Armavica here is where we allow negative axis for the user

return Unique(return_index, return_inverse, return_counts, axis)(ar)


Expand Down