Skip to content

Commit fb4b0b5

Browse files
author
Jesse Grabowski
committed
Feedback
1 parent 54c9e69 commit fb4b0b5

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

pytensor/tensor/reshape.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def L_op(
9090
(x,) = inputs
9191
(g_out,) = output_grads
9292

93-
packed_shape = shape(x)[list(self.axis_range)]
93+
x_shape = shape(x)
94+
packed_shape = [x_shape[i] for i in self.axis_range]
9495
return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]
9596

9697

@@ -105,19 +106,19 @@ def _vectorize_joindims(op, node, x):
105106
return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
106107

107108

108-
def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> Variable:
109+
def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
109110
"""Join consecutive dimensions of a tensor into a single dimension.
110111
111112
Parameters
112113
----------
113-
x : Variable
114+
x : TensorLike
114115
The input tensor.
115116
axis : int or sequence of int, optional
116117
The dimensions to join. If None, all dimensions are joined.
117118
118119
Returns
119120
-------
120-
joined_x : Variable
121+
joined_x : TensorVariable
121122
The reshaped tensor with joined dimensions.
122123
123124
Examples
@@ -237,7 +238,7 @@ def split_dims(
237238
x: TensorLike,
238239
shape: ShapeValueType | Sequence[ShapeValueType],
239240
axis: int | None = None,
240-
) -> Variable:
241+
) -> TensorVariable:
241242
"""Split a dimension of a tensor into multiple dimensions.
242243
243244
Parameters
@@ -251,7 +252,7 @@ def split_dims(
251252
252253
Returns
253254
-------
254-
split_x : Variable
255+
split_x : TensorVariable
255256
The reshaped tensor with split dimensions.
256257
257258
Examples
@@ -384,7 +385,7 @@ def pack(
384385
385386
Returns
386387
-------
387-
packed_tensor : TensorLike
388+
packed_tensor : TensorVariable
388389
The packed tensor with specified axes preserved and others raveled.
389390
packed_shapes : list of ShapeValueType
390391
A list containing the shapes of the raveled dimensions for each input tensor.
@@ -492,7 +493,7 @@ def unpack(
492493
packed_input: TensorLike,
493494
axes: int | Sequence[int] | None,
494495
packed_shapes: list[ShapeValueType],
495-
) -> list[Variable]:
496+
) -> list[TensorVariable]:
496497
"""
497498
Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
498499
@@ -514,7 +515,7 @@ def unpack(
514515
515516
Returns
516517
-------
517-
unpacked_tensors : list of TensorLike
518+
unpacked_tensors : list of TensorVariable
518519
A list of unpacked tensors with their original shapes restored.
519520
"""
520521
packed_input = as_tensor_variable(packed_input)

0 commit comments

Comments
 (0)