 from pytensor.graph.replace import _vectorize_node
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
+from pytensor.tensor.basic import infer_static_shape, join, split
 from pytensor.tensor.math import prod
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable


 class JoinDims(Op):
-    __props__ = (
-        "start_axis",
-        "n_axes",
-    )
+    __props__ = ("start_axis", "n_axes")
     view_map = {0: [0]}

     def __init__(self, start_axis: int, n_axes: int):
@@ -55,6 +52,8 @@ def make_node(self, x: Variable) -> Apply:  # type: ignore[override]

         static_shapes = x.type.shape
         axis_range = self.axis_range
+        if (self.start_axis + self.n_axes) > x.type.ndim:
+            raise ValueError("JoinDims requested to join axes that are not available")

         joined_shape = (
             int(np.prod([static_shapes[i] for i in axis_range]))
@@ -69,9 +68,7 @@ def make_node(self, x: Variable) -> Apply:  # type: ignore[override]

     def infer_shape(self, fgraph, node, shapes):
         [input_shape] = shapes
-        axis_range = self.axis_range
-
-        joined_shape = prod([input_shape[i] for i in axis_range])
+        joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int)
         return [self.output_shapes(input_shape, joined_shape)]

     def perform(self, node, inputs, outputs):
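[Review note] The static-shape logic in make_node/infer_shape mirrors a plain NumPy reshape. A minimal sketch of the intended semantics, assuming axes [start_axis, start_axis + n_axes) collapse into one axis; numpy_join_dims is a hypothetical reference helper, not part of this PR:

import numpy as np

def numpy_join_dims(x: np.ndarray, start_axis: int, n_axes: int) -> np.ndarray:
    # Collapse axes [start_axis, start_axis + n_axes) into one; an empty
    # slice has product 1, so n_axes=0 inserts a new length-1 axis.
    shape = x.shape
    joined = int(np.prod(shape[start_axis : start_axis + n_axes]))
    return x.reshape(shape[:start_axis] + (joined,) + shape[start_axis + n_axes :])

assert numpy_join_dims(np.zeros((2, 3, 4, 5)), 1, 2).shape == (2, 12, 5)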
@@ -98,23 +95,24 @@ def L_op(self, inputs, outputs, output_grads):
 @_vectorize_node.register(JoinDims)
 def _vectorize_joindims(op, node, x):
     [old_x] = node.inputs
-
     batched_ndims = x.type.ndim - old_x.type.ndim
-    start_axis = op.start_axis
-    n_axes = op.n_axes
+    return JoinDims(op.start_axis + batched_ndims, op.n_axes).make_node(x)

-    return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)

-
-def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
+def join_dims(
+    x: TensorLike, start_axis: int = 0, n_axes: int | None = None
+) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.

     Parameters
     ----------
     x : TensorLike
         The input tensor.
-    axis : int or sequence of int, optional
-        The dimensions to join. If None, all dimensions are joined.
+    start_axis : int, default 0
+        The axis from which to start joining dimensions.
+    n_axes : int, optional
+        The number of axes to join, starting from `start_axis`. If None, all
+        remaining axes are joined. If 0, a new dimension of length 1 is inserted.

     Returns
     -------
@@ -125,33 +123,32 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(2, 3, 4, 5))
-    >>> y = pt.join_dims(x, axis=(1, 2))
+    >>> y = pt.join_dims(x)
+    >>> y.type.shape
+    (120,)
+    >>> y = pt.join_dims(x, start_axis=1)
+    >>> y.type.shape
+    (2, 60)
+    >>> y = pt.join_dims(x, start_axis=1, n_axes=2)
     >>> y.type.shape
     (2, 12, 5)
     """
     x = as_tensor_variable(x)
+    ndim = x.type.ndim

-    if axis is None:
-        axis = list(range(x.ndim))
-    elif isinstance(axis, int):
-        axis = [axis]
-    elif not isinstance(axis, list | tuple):
-        raise TypeError("axis must be an int, a list/tuple of ints, or None")
-
-    axis = normalize_axis_tuple(axis, x.ndim)
+    if start_axis < 0:
+        # We treat scalars as if they had a single axis
+        start_axis += max(1, ndim)

-    if len(axis) <= 1:
-        return x  # type: ignore[unreachable]
-
-    if np.diff(axis).max() > 1:
-        raise ValueError(
-            f"join_dims axis must be consecutive, got normalized axis: {axis}"
+    if not 0 <= start_axis <= ndim:
+        raise IndexError(
+            f"Axis {start_axis} is out of bounds for array of dimension {ndim}"
         )

-    start_axis = min(axis)
-    n_axes = len(axis)
+    if n_axes is None:
+        n_axes = ndim - start_axis

-    return JoinDims(start_axis=start_axis, n_axes=n_axes)(x)  # type: ignore[return-value]
+    return JoinDims(start_axis, n_axes)(x)  # type: ignore[return-value]


 class SplitDims(Op):
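[Review note] Two behaviors of the new signature worth flagging, sketched from the docstring and the normalization code above (expected shapes inferred, not run against this branch):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 4, 5))
# Negative start_axis wraps: -2 -> 2, then n_axes defaults to the remaining 2 axes.
y = pt.join_dims(x, start_axis=-2)            # expected static shape: (2, 3, 20)
# n_axes=0 joins nothing, which per the docstring inserts a length-1 dimension.
z = pt.join_dims(x, start_axis=1, n_axes=0)   # expected static shape: (2, 1, 3, 4, 5)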
@@ -213,11 +210,11 @@ def connection_pattern(self, node):
     def L_op(self, inputs, outputs, output_grads):
         (x, _) = inputs
         (g_out,) = output_grads
-
         n_axes = g_out.ndim - x.ndim + 1
-        axis_range = list(range(self.axis, self.axis + n_axes))
-
-        return [join_dims(g_out, axis=axis_range), disconnected_type()]
+        return [
+            join_dims(g_out, start_axis=self.axis, n_axes=n_axes),
+            disconnected_type(),
+        ]


 @_vectorize_node.register(SplitDims)
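[Review note] Sanity check of the n_axes arithmetic in L_op above, with hypothetical values for illustration:

# If SplitDims turned one axis of x into `split_into` axes, the gradient g_out
# has x.ndim - 1 + split_into dimensions, so L_op recovers split_into as n_axes.
x_ndim, split_into = 3, 2                 # x: (a, b, c), axis 1 split into (b1, b2)
g_out_ndim = x_ndim - 1 + split_into      # g_out: (a, b1, b2, c)
n_axes = g_out_ndim - x_ndim + 1
assert n_axes == split_into               # join_dims(g_out, axis, n_axes) undoes the split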
@@ -230,14 +227,13 @@ def _vectorize_splitdims(op, node, x, shape):
     if as_tensor_variable(shape).type.ndim != 1:
         return vectorize_node_fallback(op, node, x, shape)

-    axis = op.axis
-    return SplitDims(axis=axis + batched_ndims).make_node(x, shape)
+    return SplitDims(axis=op.axis + batched_ndims).make_node(x, shape)


 def split_dims(
     x: TensorLike,
     shape: ShapeValueType | Sequence[ShapeValueType],
-    axis: int | None = None,
+    axis: int = 0,
 ) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.

@@ -247,8 +243,8 @@ def split_dims(
         The input tensor.
     shape : int or sequence of int
         The new shape to split the specified dimension into.
-    axis : int, optional
-        The dimension to split. If None, the input is assumed to be 1D and axis 0 is used.
+    axis : int, default 0
+        The dimension to split.

     Returns
     -------
@@ -259,22 +255,18 @@ def split_dims(
     --------
     >>> import pytensor.tensor as pt
     >>> x = pt.tensor("x", shape=(6, 4, 6))
-    >>> y = pt.split_dims(x, shape=(2, 3), axis=0)
+    >>> y = pt.split_dims(x, shape=(2, 3))
     >>> y.type.shape
     (2, 3, 4, 6)
+    >>> y = pt.split_dims(x, shape=(2, 3), axis=-1)
+    >>> y.type.shape
+    (6, 4, 2, 3)
     """
     x = as_tensor_variable(x)
-
-    if axis is None:
-        if x.type.ndim != 1:
-            raise ValueError(
-                "split_dims can only be called with axis=None for 1d inputs"
-            )
-        axis = 0
-    else:
-        axis = normalize_axis_index(axis, x.ndim)
+    axis = normalize_axis_index(axis, x.ndim)

     # Convert scalar shape to 1d tuple (shape,)
+    # Which is basically a specify_shape
     if not isinstance(shape, Sequence):
         if isinstance(shape, TensorVariable | np.ndarray):
             if shape.ndim == 0:
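[Review note] A hedged round-trip sketch (expected shapes read off the docstrings, not executed on this branch): split_dims with the original sub-shape inverts join_dims.

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 4, 5))
y = pt.join_dims(x, start_axis=1, n_axes=2)   # (2, 12, 5)
z = pt.split_dims(y, shape=(3, 4), axis=1)    # (2, 3, 4, 5) again
assert z.type.shape == x.type.shape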
@@ -313,8 +305,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
     elif not isinstance(axes, Iterable):
         raise TypeError("axes must be an int, an iterable of ints, or None")

-    axes = tuple(axes)
-
     if len(axes) == 0:
         raise ValueError("axes=[] is ambiguous; use None to ravel all")

@@ -465,22 +455,10 @@ def pack(
             f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
             f"but {keep_axes=} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
         )
-        n_after_packed = n_dim - n_after
-        packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
-
-        if n_dim == min_axes:
-            # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
-            # implied by the axes.
-            input_tensor = expand_dims(input_tensor, axis=n_before)
-            reshaped_tensors.append(input_tensor)
-            continue
-
-        # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
-        # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
-        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
-        # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
-        join_axes = range(n_before, n_after_packed)
-        joined = join_dims(input_tensor, tuple(join_axes))
+
+        n_packed = n_dim - n_after - n_before
+        packed_shapes.append(input_tensor.shape[n_before : n_before + n_packed])
+        joined = join_dims(input_tensor, n_before, n_packed)
         reshaped_tensors.append(joined)

     return join(n_before, *reshaped_tensors), packed_shapes
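[Review note] Worked example of the new pack arithmetic, with hypothetical values (an ndim-4 input where keep_axes implies one kept axis on each side):

n_dim, n_before, n_after = 4, 1, 1
n_packed = n_dim - n_after - n_before    # the 2 middle axes get raveled
# packed_shapes stores input_tensor.shape[1:3] so unpack can restore them,
# and join_dims(input_tensor, 1, 2) maps (a, b, c, d) -> (a, b*c, d).
assert n_packed == 2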