10 | 10 | from collections.abc import Sequence |
11 | 11 | from functools import partial |
12 | 12 | from numbers import Number |
13 | | -from typing import TYPE_CHECKING |
| 13 | +from typing import TYPE_CHECKING, Union |
14 | 14 | from typing import cast as type_cast |
15 | 15 |
16 | 16 | import numpy as np |
33 | 33 | from pytensor.link.c.op import COp |
34 | 34 | from pytensor.link.c.params_type import ParamsType |
35 | 35 | from pytensor.printing import Printer, min_informative_str, pprint, set_precedence |
36 | | -from pytensor.raise_op import CheckAndRaise, assert_op |
| 36 | +from pytensor.raise_op import CheckAndRaise |
37 | 37 | from pytensor.scalar import int32 |
38 | 38 | from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable |
39 | 39 | from pytensor.tensor import ( |
@@ -3084,87 +3084,72 @@ def flatten(x, ndim=1): |
3084 | 3084 | return x_reshaped |
3085 | 3085 |
3086 | 3086 |
3087 | | -def tile(x, reps, ndim=None): |
| 3087 | +def tile( |
| 3088 | + A: "TensorLike", reps: Union[Sequence[Union[int, "TensorLike"]], "TensorLike"] |
| 3089 | +) -> TensorVariable: |
3088 | 3090 | """ |
3089 | | - Tile input array `x` according to `reps`. |
| 3091 | + Tile input array `A` according to `reps`. |
3090 | 3092 |
3091 | 3093 | See the docstring of `numpy.tile` for details. |
3092 | 3094 |
3093 | | - 'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]), |
3094 | | - symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector()) |
3095 | | - or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]). |
3096 | | -
3097 | | - ndim is the number of the dimensions of the output, if it is provided, ndim |
3098 | | - should be equal or larger than x.ndim and len(reps), otherwise, we will use |
3099 | | - max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to |
3100 | | - be provided. |
3101 | | -
| 3095 | + If `reps` is a PyTensor vector, its length must be statically known. |
| 3096 | + You can use `specify_shape` to set the length. |
3102 | 3097 | """ |
3103 | | - from pytensor.tensor.math import ge |
3104 | 3098 |
3105 | | - _x = as_tensor_variable(x) |
3106 | | - if ndim is not None and ndim < _x.ndim: |
3107 | | - raise ValueError("ndim should be equal or larger than _x.ndim") |
| 3099 | + A = as_tensor_variable(A) |
3108 | 3100 |
3109 | | - # If reps is a scalar, integer or vector, we convert it to a list. |
| 3101 | + # Convert symbolic reps to a tuple |
3110 | 3102 | if not isinstance(reps, list | tuple): |
3111 | | - reps_astensor = as_tensor_variable(reps) |
3112 | | - ndim_check = reps_astensor.ndim |
3113 | | - if reps_astensor.dtype not in discrete_dtypes: |
3114 | | - raise ValueError("elements of reps must be integer dtype") |
3115 | | - |
3116 | | - # The scalar/integer case |
3117 | | - if ndim_check == 0: |
3118 | | - reps = [reps] |
3119 | | - |
3120 | | - # The vector case |
3121 | | - elif ndim_check == 1: |
3122 | | - if ndim is None: |
| 3103 | + reps = as_tensor_variable(reps) |
| 3104 | + if reps.type.ndim == 0: |
| 3105 | + reps = (reps,) |
| 3106 | + elif reps.type.ndim == 1: |
| 3107 | + try: |
| 3108 | + reps = tuple(reps) |
| 3109 | + except ValueError: |
3123 | 3110 | raise ValueError( |
3124 | | - "if reps is tensor.vector, you should specify the ndim" |
| 3111 | + "Length of repetitions tensor cannot be determined. Use specify_shape to set the length." |
3125 | 3112 | ) |
3126 | | - else: |
3127 | | - offset = ndim - reps.shape[0] |
3128 | | - |
3129 | | - # assert that reps.shape[0] does not exceed ndim |
3130 | | - offset = assert_op(offset, ge(offset, 0)) |
| 3113 | + else: |
| 3114 | + raise ValueError( |
| 3115 | + f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}" |
| 3116 | + ) |
3131 | 3117 |
3132 | | - # if reps.ndim is less than _x.ndim, we pad the reps with |
3133 | | - # "1" so that reps will have the same ndim as _x. |
3134 | | - reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)] |
3135 | | - reps = reps_ |
| 3118 | + reps = [as_tensor_variable(rep) for rep in reps] |
| 3119 | + if not all( |
| 3120 | + rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps |
| 3121 | + ): |
| 3122 | + raise ValueError( |
| 3123 | + f"All reps entries should be scalar integers, got {reps} of type {[rep.type for rep in reps]}" |
| 3124 | + ) |
3136 | 3125 |
3137 | | - # For others, raise an error |
3138 | | - else: |
3139 | | - raise ValueError("the dimension of reps should not exceed 1") |
3140 | | - else: |
3141 | | - if ndim is not None and len(reps) > ndim: |
3142 | | - raise ValueError("len(reps) should be equal or less than ndim") |
3143 | | - if not all( |
3144 | | - isinstance(r, int) |
3145 | | - or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes) |
3146 | | - for r in reps |
3147 | | - ): |
3148 | | - raise ValueError("elements of reps must be scalars of integer dtype") |
| 3126 | + len_reps = len(reps) |
| 3127 | + out_ndim = builtins.max(len_reps, A.type.ndim) |
| 3128 | + |
| 3129 | + # Pad reps on the left (if needed) |
| 3130 | + if len_reps < out_ndim: |
| 3131 | + reps = (*((1,) * (out_ndim - len_reps)), *reps) |
| 3132 | + |
| 3133 | + # Pad A's shape on the left (if needed) |
| 3134 | + elif A.type.ndim < out_ndim: |
| 3135 | + A = shape_padleft(A, out_ndim - A.type.ndim) |
| 3136 | + |
| 3137 | + # Insert a new axis before every dim of A, then replicate each new axis reps[i] times via Alloc |
| 3138 | + # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1]) |
| 3139 | + A_shape = A.shape |
| 3140 | + interleaved_reps_shape = [ |
| 3141 | + d for pair in zip(reps, A_shape, strict=True) for d in pair |
| 3142 | + ] |
| 3143 | + every_other_axis = tuple(range(0, out_ndim * 2, 2)) |
| 3144 | + A_replicated = alloc( |
| 3145 | + expand_dims(A, every_other_axis), |
| 3146 | + *interleaved_reps_shape, |
| 3147 | + ) |
3149 | 3148 |
3150 | | - # If reps.ndim is less than _x.ndim, we pad the reps with |
3151 | | - # "1" so that reps will have the same ndim as _x |
3152 | | - reps = list(reps) |
3153 | | - if ndim is None: |
3154 | | - ndim = builtins.max(len(reps), _x.ndim) |
3155 | | - if len(reps) < ndim: |
3156 | | - reps = [1] * (ndim - len(reps)) + reps |
3157 | | - |
3158 | | - _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)] |
3159 | | - alloc_shape = reps + _shape |
3160 | | - y = alloc(_x, *alloc_shape) |
3161 | | - shuffle_ind = np.arange(ndim * 2).reshape(2, ndim) |
3162 | | - shuffle_ind = shuffle_ind.transpose().flatten() |
3163 | | - y = y.dimshuffle(*shuffle_ind) |
3164 | | - new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)] |
3165 | | - y = y.reshape(new_shapes) |
3166 | | - |
3167 | | - return y |
| 3149 | + # Combine replicate and original dimensions via reshape |
| 3150 | + # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1]) |
| 3151 | + tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True)) |
| 3152 | + return A_replicated.reshape(tiled_shape) |
3168 | 3153 |
3169 | 3154 |
3170 | 3155 | class ARange(Op): |
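
For context, a minimal usage sketch of the new signature. This is illustrative only (the variable names are hypothetical, not part of this diff) and assumes the behavior described in the docstring above: a symbolic reps vector needs a statically known length, which specify_shape can provide.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")

    # Constant reps behave like numpy.tile.
    y = pt.tile(x, (2, 3))

    # A symbolic reps vector must have a statically known length;
    # specify_shape provides one.
    reps_raw = pt.lvector("reps")
    reps = pt.specify_shape(reps_raw, (2,))
    z = pt.tile(x, reps)

    f = pytensor.function([x, reps_raw], z)
    out = f(np.ones((2, 2)), np.array([2, 3]))
    assert out.shape == (4, 6)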
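The comments in the new implementation describe the core trick: interleave a replication axis before every dimension of A, allocate, then collapse each (rep, dim) pair with a single reshape. Below is a NumPy analogue of that trick, with broadcast_to standing in for PyTensor's Alloc (an illustrative sketch, not part of this diff).

    import numpy as np

    A = np.arange(6).reshape(2, 3)
    reps = (2, 3)

    # Insert a new axis before every dim and broadcast it to reps[i]
    # (broadcast_to plays the role of Alloc here).
    A_rep = np.broadcast_to(
        A[None, :, None, :],
        (reps[0], A.shape[0], reps[1], A.shape[1]),
    )

    # Merge each (rep, dim) pair with a single reshape.
    tiled = A_rep.reshape(reps[0] * A.shape[0], reps[1] * A.shape[1])
    assert np.array_equal(tiled, np.tile(A, reps))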