57 changes: 42 additions & 15 deletions pytensor/tensor/reshape.py
@@ -1,18 +1,19 @@
 from collections.abc import Iterable, Sequence
 from itertools import pairwise
 from typing import cast as type_cast
 
 import numpy as np
 from numpy.lib._array_utils_impl import normalize_axis_tuple
 
 from pytensor import Variable
+from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import prod
-from pytensor.tensor.shape import ShapeValueType
+from pytensor.tensor.shape import ShapeValueType, shape
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable

@@ -80,6 +81,19 @@ def perform(self, node, inputs, outputs):
 
         out[0] = x.reshape(output_shape)
 
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        (x,) = inputs
+        (g_out,) = output_grads
+
+        x_shape = shape(x)
+        packed_shape = [x_shape[i] for i in self.axis_range]
+        return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
+
 
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
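
Note on the new `JoinDims.L_op`: since the op only reshapes, the gradient just splits the output cotangent back to the input's shape via `split_dims`. A minimal sketch of the behavior this enables, assuming the `join_dims` import path from this PR (everything else is standard PyTensor API):

```python
# Sketch only: gradients flow through JoinDims and keep the input's shape.
import numpy as np

import pytensor
from pytensor import config
from pytensor.tensor import tensor
from pytensor.tensor.reshape import join_dims

x = tensor("x", shape=(2, 3, 4))
y = join_dims(x, axis=(1, 2))  # reshapes (2, 3, 4) -> (2, 12)
g = pytensor.grad(y.sum(), x)  # backward pass is a SplitDims back to (2, 3, 4)
fn = pytensor.function([x], g)

x_val = np.zeros((2, 3, 4), dtype=config.floatX)
assert fn(x_val).shape == (2, 3, 4)
```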
@@ -97,14 +111,14 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
 
     Parameters
     ----------
-    x : Variable
+    x : TensorLike
         The input tensor.
     axis : int or sequence of int, optional
         The dimensions to join. If None, all dimensions are joined.
 
     Returns
     -------
-    joined_x : Variable
+    joined_x : TensorVariable
         The reshaped tensor with joined dimensions.
 
     Examples
@@ -137,10 +151,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
     start_axis = min(axis)
     n_axes = len(axis)
 
-    return type_cast(
-        TensorVariable,
-        JoinDims(start_axis=start_axis, n_axes=n_axes)(x),
-    )
+    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
 
 
 class SplitDims(Op):
@@ -191,6 +202,23 @@ def perform(self, node, inputs, outputs):
 
         out[0] = x.reshape(output_shape)
 
+    def connection_pattern(self, node):
+        return [[True], [False]]
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        (x, _) = inputs
+        (g_out,) = output_grads
+
+        n_axes = g_out.ndim - x.ndim + 1  # type: ignore[attr-defined]
+        axis_range = list(range(self.axis, self.axis + n_axes))
+
+        return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
+
 
 @_vectorize_node.register(SplitDims)
 def _vectorize_splitdims(op, node, x, shape):
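
Note on `SplitDims`: `connection_pattern` marks the output as connected to `x` but not to the symbolic `shape` input, which is why `L_op` returns `DisconnectedType()()` for the latter; the cotangent itself is re-joined with `join_dims`. A minimal sketch under the same assumptions as above:

```python
# Sketch only: grad wrt x flows back through a JoinDims; the integer
# shape argument carries no gradient (declared disconnected above).
import numpy as np

import pytensor
from pytensor import config
from pytensor.tensor import tensor
from pytensor.tensor.reshape import split_dims

x = tensor("x", shape=(6,))
y = split_dims(x, shape=(2, 3), axis=0)  # reshapes (6,) -> (2, 3)
g = pytensor.grad(y.sum(), x)            # re-joined back to shape (6,)

x_val = np.zeros(6, dtype=config.floatX)
assert pytensor.function([x], g)(x_val).shape == (6,)
```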
@@ -224,7 +252,7 @@ def split_dims(
 
     Returns
     -------
-    split_x : Variable
+    split_x : TensorVariable
         The reshaped tensor with split dimensions.
 
     Examples
@@ -253,13 +281,12 @@ def split_dims(
         # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
         # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
         # (3, ) and (3, 3) to (3, 4)
-        return type_cast(TensorVariable, x.squeeze(axis=axis))
+        return squeeze(x, axis=axis)  # type: ignore[no-any-return]
 
     [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
     shape = as_tensor_variable(shape, dtype="int64", ndim=1)  # type: ignore[arg-type]
 
-    split_op = SplitDims(axis=axis)
-    return type_cast(TensorVariable, split_op(x, shape))
+    return SplitDims(axis=axis)(x, shape)  # type: ignore[return-value]
 
 
 def _analyze_axes_list(axes) -> tuple[int, int, int]:
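
Note on the empty-shape branch above: when `shape` is empty, `split_dims` bypasses the op and squeezes the requested axis instead, matching the packed-dummy-dimension case the comment describes. A small sketch, assuming `shape=()` reaches that branch:

```python
# Sketch only: split_dims with an empty shape should drop a dummy axis
# rather than build a SplitDims node.
from pytensor.tensor import tensor
from pytensor.tensor.reshape import split_dims

x = tensor("x", shape=(3, 1, 4))
y = split_dims(x, shape=(), axis=1)  # equivalent to squeeze(x, axis=1)
assert y.type.shape == (3, 4)        # assumed: static shape is preserved
```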
@@ -358,7 +385,7 @@ def pack(
 
     Returns
     -------
-    packed_tensor : TensorLike
+    packed_tensor : TensorVariable
         The packed tensor with specified axes preserved and others raveled.
     packed_shapes : list of ShapeValueType
         A list containing the shapes of the raveled dimensions for each input tensor.
@@ -430,7 +457,7 @@ def pack(
 
     n_before, n_after, min_axes = _analyze_axes_list(axes)
 
-    reshaped_tensors: list[TensorVariable] = []
+    reshaped_tensors: list[Variable] = []
     packed_shapes: list[ShapeValueType] = []
 
     for i, input_tensor in enumerate(tensor_list):
@@ -488,7 +515,7 @@ def unpack(
 
     Returns
     -------
-    unpacked_tensors : list of TensorLike
+    unpacked_tensors : list of TensorVariable
         A list of unpacked tensors with their original shapes restored.
     """
     packed_input = as_tensor_variable(packed_input)
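
The corrected `pack`/`unpack` return types read naturally as a round trip. A hedged sketch follows; the list-of-tensors call signature for `pack` is an assumption (only `unpack`'s signature is visible in the test change below), so treat this as illustrative rather than the definitive API:

```python
# Sketch only: pack ravels the non-preserved dims of each input and
# unpack restores the original shapes from packed_shapes.
from pytensor.tensor import tensor
from pytensor.tensor.reshape import pack, unpack

x = tensor("x", shape=(3,))
y = tensor("y", shape=(3, 3))

# Call signature assumed from the surrounding code (tensor_list, axes).
packed, packed_shapes = pack([x, y], axes=None)
x2, y2 = unpack(packed, axes=None, packed_shapes=packed_shapes)
# Expected: x2 and y2 recover the shapes (3,) and (3, 3).
```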
7 changes: 6 additions & 1 deletion tests/tensor/test_reshape.py
@@ -2,6 +2,7 @@
 import pytest
 
 import pytensor
+import tests.unittest_tools as utt
 from pytensor import config, function
 from pytensor import tensor as pt
 from pytensor.graph import rewrite_graph, vectorize_graph
@@ -52,6 +53,8 @@ def test_join_dims():
     x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
     assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)
 
+    utt.verify_grad(lambda x: join_dims(x, axis=(1, 2)), [x_value])
+
 
 @pytest.mark.parametrize(
     "axis, shape, expected_shape",
@@ -77,6 +80,8 @@ def test_split_dims(axis, shape, expected_shape):
     x_split_value = fn(x_value)
     np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape))
 
+    utt.verify_grad(lambda x: split_dims(x, shape=shape, axis=axis), [x_value])
+
     x = pt.tensor("x", shape=(10,))
     x_split = split_dims(x, shape=(5, 2), axis=0)
     x_batched = pt.tensor("x_batched", shape=(3, 10))
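
The two `utt.verify_grad` calls added above compare the symbolic gradient of each op against a finite-difference estimate. A simplified sketch of that check for `join_dims` (the real helper is more general, e.g. it projects the output randomly; this is not its implementation):

```python
# Simplified sketch of what verify_grad checks: analytic gradient vs.
# central finite differences on a scalar reduction of the output.
import numpy as np

import pytensor
from pytensor import config
from pytensor.tensor import tensor
from pytensor.tensor.reshape import join_dims

x = tensor("x", shape=(2, 3, 5))
cost = join_dims(x, axis=(1, 2)).sum()
cost_fn = pytensor.function([x], cost)
grad_fn = pytensor.function([x], pytensor.grad(cost, x))

rng = np.random.default_rng(0)
x0 = rng.normal(size=(2, 3, 5)).astype(config.floatX)

eps = 1e-3
numeric = np.zeros_like(x0)
for idx in np.ndindex(*x0.shape):
    dx = np.zeros_like(x0)
    dx[idx] = eps
    numeric[idx] = (cost_fn(x0 + dx) - cost_fn(x0 - dx)) / (2 * eps)

np.testing.assert_allclose(grad_fn(x0), numeric, rtol=1e-2, atol=1e-2)
```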
@@ -115,7 +120,7 @@ def test_make_replacements_with_pack_unpack():
     new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
 
     loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
-    rewrite_graph(loss, include=("ShapeOpt", "specialize"))
+    rewrite_graph(loss, include=("ShapeOpt", "canonicalize"))
 
     fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
 