@@ -12,7 +12,7 @@
 from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
+from pytensor.tensor.basic import infer_static_shape, join, split
 from pytensor.tensor.math import prod
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
@@ -24,10 +24,7 @@
 
 
 class JoinDims(Op):
-    __props__ = (
-        "start_axis",
-        "n_axes",
-    )
+    __props__ = ("start_axis", "n_axes")
     view_map = {0: [0]}
 
     def __init__(self, start_axis: int, n_axes: int):
@@ -55,6 +52,8 @@ def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
 
         static_shapes = x.type.shape
         axis_range = self.axis_range
+        if (self.start_axis + self.n_axes) > x.type.ndim:
+            raise ValueError("JoinDims requested to join axes that are not available")
 
         joined_shape = (
            int(np.prod([static_shapes[i] for i in axis_range]))
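The new guard fails fast at graph-construction time instead of at evaluation time. A minimal sketch of what now raises, assuming the op is built through the public `join_dims` helper:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3))
    # start_axis + n_axes = 1 + 2 = 3 > x.type.ndim = 2, so make_node raises ValueError
    pt.join_dims(x, start_axis=1, n_axes=2)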
@@ -69,9 +68,7 @@ def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
 
     def infer_shape(self, fgraph, node, shapes):
         [input_shape] = shapes
-        axis_range = self.axis_range
-
-        joined_shape = prod([input_shape[i] for i in axis_range], dtype=int)
+        joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int)
         return [self.output_shapes(input_shape, joined_shape)]
 
     def perform(self, node, inputs, outputs):
@@ -98,23 +95,24 @@ def L_op(self, inputs, outputs, output_grads):
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
     [old_x] = node.inputs
-
     batched_ndims = x.type.ndim - old_x.type.ndim
-    start_axis = op.start_axis
-    n_axes = op.n_axes
+    return JoinDims(op.start_axis + batched_ndims, op.n_axes).make_node(x)
 
-    return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
 
-
-def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
+def join_dims(
+    x: TensorLike, start_axis: int = 0, n_axes: int | None = None
+) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.
 
     Parameters
     ----------
     x : TensorLike
         The input tensor.
-    axis : int or sequence of int, optional
-        The dimensions to join. If None, all dimensions are joined.
+    start_axis : int, default 0
+        The axis from which to start joining dimensions.
+    n_axes : int, optional
+        The number of axes to join, starting from `start_axis`. If `None`,
+        all remaining axes are joined. If 0, a new dimension of length 1 is inserted.
 
     Returns
     -------
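Since `_vectorize_joindims` only shifts `start_axis` by the number of new batch dimensions, vectorizing a `JoinDims` graph stays a single op. A small sketch using `vectorize_graph` (shapes here are illustrative):

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    x = pt.tensor("x", shape=(3, 4, 5))
    y = pt.join_dims(x, start_axis=1, n_axes=2)  # shape (3, 20)

    # Adding one leading batch dimension shifts start_axis from 1 to 2
    xb = pt.tensor("xb", shape=(7, 3, 4, 5))
    yb = vectorize_graph(y, replace={x: xb})
    assert yb.type.shape == (7, 3, 20)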
@@ -125,33 +123,31 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(2, 3, 4, 5))
-    >>> y = pt.join_dims(x, axis=(1, 2))
+    >>> y = pt.join_dims(x)
+    >>> y.type.shape
+    (120,)
+    >>> y = pt.join_dims(x, start_axis=1)
+    >>> y.type.shape
+    (2, 60)
+    >>> y = pt.join_dims(x, start_axis=1, n_axes=2)
     >>> y.type.shape
     (2, 12, 5)
     """
-    x = as_tensor_variable(x)
-
-    if axis is None:
-        axis = list(range(x.ndim))
-    elif isinstance(axis, int):
-        axis = [axis]
-    elif not isinstance(axis, list | tuple):
-        raise TypeError("axis must be an int, a list/tuple of ints, or None")
+    ndim = x.ndim
 
-    axis = normalize_axis_tuple(axis, x.ndim)
+    if start_axis < 0:
+        # We treat scalars as if they had a single axis
+        start_axis += max(1, ndim)
 
-    if len(axis) <= 1:
-        return x  # type: ignore[unreachable]
-
-    if np.diff(axis).max() > 1:
-        raise ValueError(
-            f"join_dims axis must be consecutive, got normalized axis: {axis}"
+    if not 0 <= start_axis <= ndim:
+        raise IndexError(
+            f"Axis {start_axis} is out of bounds for array of dimension {ndim}"
         )
 
-    start_axis = min(axis)
-    n_axes = len(axis)
+    if n_axes is None:
+        n_axes = ndim - start_axis
 
-    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
+    return JoinDims(start_axis, n_axes)(x)
 
 
 class SplitDims(Op):
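The new API reads like a reshape that collapses `n_axes` consecutive dimensions starting at `start_axis`. A quick sketch of that equivalence (the NumPy comparison is illustrative, not part of this PR; the `n_axes=0` case follows the docstring above):

    import numpy as np
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4, 5))

    # Collapsing axes 1 and 2 matches a reshape with -1 in their place
    y = pt.join_dims(x, start_axis=1, n_axes=2)
    assert y.type.shape == np.zeros((2, 3, 4, 5)).reshape(2, -1, 5).shape

    # n_axes=0 inserts a new length-1 dimension at start_axis
    z = pt.join_dims(x, start_axis=2, n_axes=0)
    assert z.type.shape == (2, 3, 1, 4, 5)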
@@ -213,11 +209,11 @@ def connection_pattern(self, node):
     def L_op(self, inputs, outputs, output_grads):
         (x, _) = inputs
         (g_out,) = output_grads
-
         n_axes = g_out.ndim - x.ndim + 1
-        axis_range = list(range(self.axis, self.axis + n_axes))
-
-        return [join_dims(g_out, axis=axis_range), disconnected_type()]
+        return [
+            join_dims(g_out, start_axis=self.axis, n_axes=n_axes),
+            disconnected_type(),
+        ]
 
 
 @_vectorize_node.register(SplitDims)
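The gradient is the usual reshape duality: axes split in the forward pass are joined back in the backward pass. A shape-level sketch with hypothetical values:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(6, 4))
    y = pt.split_dims(x, shape=(2, 3), axis=0)    # forward: (6, 4) -> (2, 3, 4)

    g_out = pt.tensor("g_out", shape=(2, 3, 4))   # cotangent of y
    n_axes = g_out.ndim - x.ndim + 1              # 3 - 2 + 1 = 2 split axes
    g_x = pt.join_dims(g_out, start_axis=0, n_axes=n_axes)
    assert g_x.type.shape == (6, 4)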
@@ -230,14 +226,13 @@ def _vectorize_splitdims(op, node, x, shape):
     if as_tensor_variable(shape).type.ndim != 1:
         return vectorize_node_fallback(op, node, x, shape)
 
-    axis = op.axis
-    return SplitDims(axis=axis + batched_ndims).make_node(x, shape)
+    return SplitDims(axis=op.axis + batched_ndims).make_node(x, shape)
 
 
 def split_dims(
     x: TensorLike,
     shape: ShapeValueType | Sequence[ShapeValueType],
-    axis: int | None = None,
+    axis: int = 0,
 ) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.
243238
@@ -247,8 +242,8 @@ def split_dims(
247242 The input tensor.
248243 shape : int or sequence of int
249244 The new shape to split the specified dimension into.
250- axis : int, optional
251- The dimension to split. If None, the input is assumed to be 1D and axis 0 is used.
245+ axis : int, default 0
246+ The dimension to split.
252247
253248 Returns
254249 -------
@@ -259,22 +254,18 @@ def split_dims(
259254 --------
260255 >>> import pytensor.tensor as pt
261256 >>> x = pt.tensor("x", shape=(6, 4, 6))
262- >>> y = pt.split_dims(x, shape=(2, 3), axis=0 )
257+ >>> y = pt.split_dims(x, shape=(2, 3))
263258 >>> y.type.shape
264259 (2, 3, 4, 6)
260+ >>> y = pt.split_dims(x, shape=(2, 3), axis=-1)
261+ >>> y.type.shape
262+ (6, 4, 2, 3)
265263 """
266264 x = as_tensor_variable (x )
267-
268- if axis is None :
269- if x .type .ndim != 1 :
270- raise ValueError (
271- "split_dims can only be called with axis=None for 1d inputs"
272- )
273- axis = 0
274- else :
275- axis = normalize_axis_index (axis , x .ndim )
265+ axis = normalize_axis_index (axis , x .ndim )
276266
277267 # Convert scalar shape to 1d tuple (shape,)
268+ # Which is basically a specify_shape
278269 if not isinstance (shape , Sequence ):
279270 if isinstance (shape , TensorVariable | np .ndarray ):
280271 if shape .ndim == 0 :
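With `axis` now defaulting to 0 and accepting negative values, round-tripping with `join_dims` is direct. A small sketch (shapes illustrative):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(6, 4, 6))

    # Negative axes are normalized, as in the docstring example
    y = pt.split_dims(x, shape=(2, 3), axis=-1)   # (6, 4, 2, 3)

    # split_dims and join_dims are shape inverses of one another
    z = pt.join_dims(y, start_axis=2, n_axes=2)   # back to (6, 4, 6)
    assert z.type.shape == x.type.shape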
@@ -313,8 +304,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
     elif not isinstance(axes, Iterable):
         raise TypeError("axes must be an int, an iterable of ints, or None")
 
-    axes = tuple(axes)
-
     if len(axes) == 0:
         raise ValueError("axes=[] is ambiguous; use None to ravel all")
 
@@ -464,22 +453,10 @@ def pack(
                 f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
                 f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
             )
-        n_after_packed = n_dim - n_after
-        packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
-
-        if n_dim == min_axes:
-            # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
-            # implied by the axes.
-            input_tensor = expand_dims(input_tensor, axis=n_before)
-            reshaped_tensors.append(input_tensor)
-            continue
-
-        # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
-        # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
-        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
-        # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
-        join_axes = range(n_before, n_after_packed)
-        joined = join_dims(input_tensor, tuple(join_axes))
+
+        n_packed = n_dim - n_after - n_before
+        packed_shapes.append(input_tensor.shape[n_before : n_before + n_packed])
+        joined = join_dims(input_tensor, n_before, n_packed)
         reshaped_tensors.append(joined)
 
     return join(n_before, *reshaped_tensors), packed_shapes
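With the loop expressed directly in terms of `join_dims`, pack's behavior is easier to state: each input keeps its first `n_before` and last `n_after` axes, has the middle axes raveled into one, and the results are concatenated along axis `n_before`. A rough sketch of the same steps using only the public helpers (names and shapes are illustrative):

    import pytensor.tensor as pt

    # Two inputs sharing the leading and trailing axes; middle axes get packed.
    a = pt.tensor("a", shape=(2, 3, 4, 5))   # n_before=1, n_after=1 -> 2 middle axes
    b = pt.tensor("b", shape=(2, 6, 5))      # 1 middle axis

    a_flat = pt.join_dims(a, start_axis=1, n_axes=2)  # (2, 12, 5)
    b_flat = pt.join_dims(b, start_axis=1, n_axes=1)  # (2, 6, 5)

    packed = pt.join(1, a_flat, b_flat)  # concatenate on axis 1 -> (2, 18, 5)
    assert packed.type.shape == (2, 18, 5)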