Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,18 +2254,19 @@ def infer_shape(self, fgraph, node, in_shapes):
out_shapes.append(temp)
return out_shapes

def connection_pattern(self, node):
    """Describe which inputs each output depends on.

    The op takes three inputs ``(x, axis, n)``; every output depends on
    ``x`` and ``axis`` but not on ``n``, so gradients never flow to ``n``.
    """
    n_out = len(node.outputs)
    # One row per input: x and axis feed every output; n feeds none.
    return [[connected] * n_out for connected in (True, True, False)]

def L_op(self, inputs, outputs, g_outputs):
"""Join the gradients along the axis that was used to split x."""
_x, axis, n = inputs
_x, axis, _n = inputs

# If all the output gradients are disconnected, then so are the inputs
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
return [
DisconnectedType()(),
grad_undefined(self, 1, axis),
grad_undefined(self, 2, n),
]
# Else, we have to make them zeros before joining them
# We have to convert disconnected outputs to zeros before joining them
new_g_outputs = []
for o, g in zip(outputs, g_outputs, strict=True):
if isinstance(g.type, DisconnectedType):
Expand All @@ -2276,7 +2277,7 @@ def L_op(self, inputs, outputs, g_outputs):
return [
join(axis, *new_g_outputs),
grad_undefined(self, 1, axis),
grad_undefined(self, 2, n),
DisconnectedType()(),
]

def R_op(self, inputs, eval_points):
Expand Down
27 changes: 15 additions & 12 deletions pytensor/tensor/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,7 @@ def __init__(self, axis: int):

def make_node(self, x: Variable, shape: Variable) -> Apply: # type: ignore[override]
x = as_tensor_variable(x)
shape = as_tensor_variable(shape, dtype=int, ndim=1)

if shape.type.numpy_dtype.kind not in "iu":
raise TypeError("shape must be an integer tensor")

shape = as_tensor_variable(shape).astype(int)
axis = self.axis
_, constant_shape = infer_static_shape(shape)

Expand Down Expand Up @@ -272,19 +268,26 @@ def split_dims(
)
axis = 0

if isinstance(shape, int):
shape = [shape]
else:
shape = list(shape) # type: ignore[arg-type]

if not shape:
if not isinstance(shape, Sequence):
if isinstance(shape, TensorVariable):
if shape.ndim > 1:
raise ValueError(
"If shape is not a sequence, it must be a scalar, 0D, or 1D tensor"
)
elif shape.ndim == 0:
shape = (shape,)
# else: shape.ndim == 1, use as-is
elif np.isscalar(shape):
shape = (shape,) # type: ignore[assignment]

if shape is None or (isinstance(shape, Sequence) and len(shape) == 0):
# If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
# example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
# (3, ) and (3, 3) to (3, 4)
return squeeze(x, axis=axis) # type: ignore[no-any-return]

[axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc]
shape = as_tensor_variable(shape, dtype="int64", ndim=1) # type: ignore[arg-type]
shape = as_tensor_variable(shape, dtype="int64") # type: ignore[arg-type]

return SplitDims(axis=axis)(x, shape) # type: ignore[return-value]

Expand Down
13 changes: 12 additions & 1 deletion tests/tensor/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pytensor import config, function
from pytensor import tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.graph.op import io_connection_pattern
from pytensor.tensor.reshape import (
_analyze_axes_list,
join_dims,
Expand Down Expand Up @@ -61,9 +62,10 @@ def test_join_dims():
[
(0, pt.as_tensor([2, 3]), (2, 3, 4, 6)),
(2, [2, 3], (6, 4, 2, 3)),
(-1, pt.as_tensor(6), (6, 4, 6)),
(-1, 6, (6, 4, 6)),
],
ids=["tensor", "list", "integer"],
ids=["tensor list", "integer list", "tensor", "integer"],
)
def test_split_dims(axis, shape, expected_shape):
rng = np.random.default_rng()
Expand Down Expand Up @@ -288,3 +290,12 @@ def test_pack_unpack_round_trip(self, axes):

for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
np.testing.assert_allclose(input_val, output_val)


def test_unpack_connection():
    # Unpacking x into two pieces sized by d0/d1: the summed result must be
    # connected to x but disconnected from the shape scalars d0 and d1.
    x = pt.vector("x")
    dims = [pt.scalar(name, dtype=int) for name in ("d0", "d1")]
    x0, x1 = pt.unpack(x, axes=None, packed_shapes=dims)
    total = x0.sum() + x1.sum()
    assert io_connection_pattern([x, *dims], [total]) == [[True], [False], [False]]
Loading