1010from collections .abc import Sequence
1111from functools import partial
1212from numbers import Number
13- from typing import TYPE_CHECKING
13+ from typing import TYPE_CHECKING , Union
1414from typing import cast as type_cast
1515
1616import numpy as np
3333from pytensor .link .c .op import COp
3434from pytensor .link .c .params_type import ParamsType
3535from pytensor .printing import Printer , min_informative_str , pprint , set_precedence
36- from pytensor .raise_op import CheckAndRaise , assert_op
36+ from pytensor .raise_op import CheckAndRaise
3737from pytensor .scalar import int32
3838from pytensor .scalar .basic import ScalarConstant , ScalarType , ScalarVariable
3939from pytensor .tensor import (
@@ -3084,7 +3084,9 @@ def flatten(x, ndim=1):
30843084 return x_reshaped
30853085
30863086
def tile(
    A: "TensorLike", reps: Union[int, Sequence[Union[int, "TensorLike"]], "TensorLike"]
) -> TensorVariable:
    """
    Tile input array `A` according to `reps`.

    See the docstring of `numpy.tile` for details.

    Parameters
    ----------
    A
        The tensor to be tiled.
    reps
        Number of repetitions along each dimension: an integer, a symbolic
        integer scalar, a sequence of (symbolic) integer scalars, or a
        symbolic integer vector whose length is statically known
        (use `specify_shape` to fix the length if needed).

    Returns
    -------
    TensorVariable
        The tiled tensor, with ``ndim == max(A.ndim, len(reps))``.

    Raises
    ------
    ValueError
        If `reps` is a symbolic vector of unknown length, has more than one
        dimension, or contains non-integer / non-scalar entries.
    """
    A = as_tensor_variable(A)

    # Convert a symbolic reps tensor to a tuple of scalar entries
    if not isinstance(reps, list | tuple):
        reps = as_tensor_variable(reps)
        if reps.type.ndim == 0:
            reps = (reps,)
        elif reps.type.ndim == 1:
            try:
                # Iterating a symbolic vector requires a static length
                reps = tuple(reps)
            except ValueError as err:
                raise ValueError(
                    "Length of repetitions tensor cannot be determined. Use specify_shape to set the length."
                ) from err
        else:
            raise ValueError(
                f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}"
            )

    reps = [as_tensor_variable(rep) for rep in reps]
    if not all(
        rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps
    ):
        raise ValueError(
            f"All reps entries should be scalar integers, got {reps} of type {[rep.type for rep in reps]}"
        )

    len_reps = len(reps)
    out_ndim = builtins.max(len_reps, A.type.ndim)

    # Pad reps on the left (if needed); mutually exclusive with the branch
    # below because out_ndim is the max of the two lengths.
    if len_reps < out_ndim:
        reps = (*((1,) * (out_ndim - len_reps)), *reps)

    # Pad A's shape on the left (if needed)
    elif A.type.ndim < out_ndim:
        A = shape_padleft(A, out_ndim - A.type.ndim)

    # Expand every other dim of A and expand n-reps via Alloc
    # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1])
    A_shape = A.shape
    interleaved_reps_shape = [
        d for pair in zip(reps, A_shape, strict=True) for d in pair
    ]
    every_other_axis = tuple(range(0, out_ndim * 2, 2))
    A_replicated = alloc(
        expand_dims(A, every_other_axis),
        *interleaved_reps_shape,
    )

    # Combine replicate and original dimensions via reshape
    # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1])
    tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True))
    return A_replicated.reshape(tiled_shape)
31683155
31693156
31703157class ARange (Op ):
0 commit comments