@@ -90,7 +90,8 @@ def L_op(
9090 (x ,) = inputs
9191 (g_out ,) = output_grads
9292
93- packed_shape = shape (x )[list (self .axis_range )]
93+ x_shape = shape (x )
94+ packed_shape = [x_shape [i ] for i in self .axis_range ]
9495 return [split_dims (g_out , shape = packed_shape , axis = self .start_axis )]
9596
9697
@@ -105,19 +106,19 @@ def _vectorize_joindims(op, node, x):
105106 return JoinDims (start_axis + batched_ndims , n_axes ).make_node (x )
106107
107108
108- def join_dims (x : TensorLike , axis : Sequence [int ] | int | None = None ) -> Variable :
109+ def join_dims (x : TensorLike , axis : Sequence [int ] | int | None = None ) -> TensorVariable :
109110 """Join consecutive dimensions of a tensor into a single dimension.
110111
111112 Parameters
112113 ----------
113- x : Variable
114+ x : TensorLike
114115 The input tensor.
115116 axis : int or sequence of int, optional
116117 The dimensions to join. If None, all dimensions are joined.
117118
118119 Returns
119120 -------
120- joined_x : Variable
121+ joined_x : TensorVariable
121122 The reshaped tensor with joined dimensions.
122123
123124 Examples
@@ -237,7 +238,7 @@ def split_dims(
237238 x : TensorLike ,
238239 shape : ShapeValueType | Sequence [ShapeValueType ],
239240 axis : int | None = None ,
240- ) -> Variable :
241+ ) -> TensorVariable :
241242 """Split a dimension of a tensor into multiple dimensions.
242243
243244 Parameters
@@ -251,7 +252,7 @@ def split_dims(
251252
252253 Returns
253254 -------
254- split_x : Variable
255+ split_x : TensorVariable
255256 The reshaped tensor with split dimensions.
256257
257258 Examples
@@ -384,7 +385,7 @@ def pack(
384385
385386 Returns
386387 -------
387- packed_tensor : TensorLike
388+ packed_tensor : TensorVariable
388389 The packed tensor with specified axes preserved and others raveled.
389390 packed_shapes : list of ShapeValueType
390391 A list containing the shapes of the raveled dimensions for each input tensor.
@@ -492,7 +493,7 @@ def unpack(
492493 packed_input : TensorLike ,
493494 axes : int | Sequence [int ] | None ,
494495 packed_shapes : list [ShapeValueType ],
495- ) -> list [Variable ]:
496+ ) -> list [TensorVariable ]:
496497 """
497498 Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
498499
@@ -514,7 +515,7 @@ def unpack(
514515
515516 Returns
516517 -------
517- unpacked_tensors : list of TensorLike
518+ unpacked_tensors : list of TensorVariable
518519 A list of unpacked tensors with their original shapes restored.
519520 """
520521 packed_input = as_tensor_variable (packed_input )
0 commit comments