50 | 50 | tensor, |
51 | 51 | uint_dtypes, |
52 | 52 | ) |
53 | | -from pytensor.tensor.utils import as_list, normalize_reduce_axis |
| 53 | +from pytensor.tensor.utils import normalize_reduce_axis |
54 | 54 | from pytensor.tensor.variable import ( |
55 | 55 | TensorVariable, |
56 | 56 | _tensor_py_operators, |
@@ -3208,133 +3208,6 @@ def dense_dot(a, b): |
3208 | 3208 | return _dot(a, b) |
3209 | 3209 |
3210 | 3210 |
3211 | | -def _tensordot_as_dot(a, b, axes, dot, batched): |
3212 | | - """ |
3213 | | - Reduces a tensor dot product to a matrix or vector dot product. Based |
3214 | | - on code from Tijmen Tieleman's gnumpy |
3215 | | - (http://www.cs.toronto.edu/~tijmen/gnumpy.html). |
3216 | | -
3217 | | - Please see the documentation of tensordot for the meaning of the a, b |
3218 | | - and axes arguments. |
3219 | | -
3220 | | - :param dot: a function that accepts two symbolic variables and computes |
3221 | | - the appropriate dot product (e.g. dot, batched_dot) |
3222 | | - :type dot: function |
3223 | | -
3224 | | - :param batched: whether to treat the first axis of a and b as a batch |
3225 | | - axis. If so, this axis will be preserved in the output, |
3226 | | - allowing this function to be used also for batched |
3227 | | - tensor dot products. |
3228 | | - :type batched: boolean |
3229 | | -
3230 | | - :returns: a tensor with shape equal to the concatenation of a's shape |
3231 | | - (less any dimensions that were summed over) and b's shape |
3232 | | - (less the first dimension and any dimensions that were summed |
3233 | | - over). |
3234 | | - :rtype: symbolic tensor |
3235 | | - """ |
3236 | | - a, b = as_tensor_variable(a), as_tensor_variable(b) |
3237 | | - |
3238 | | - if not np.isscalar(axes) and len(axes) != 2: |
3239 | | - raise ValueError( |
3240 | | - "Axes should be an integer or a " |
3241 | | - f"list/tuple of len 2 ({axes} was provided)" |
3242 | | - ) |
3243 | | - |
3244 | | - # if 'axes' is a number of axes to multiply and sum over (trailing axes |
3245 | | - # of a, leading axes of b), we can just reshape and use dot. |
3246 | | - elif np.isscalar(axes): |
3247 | | - axes = int(axes) |
3248 | | - |
3249 | | - for operand_name, operand in (("a", a), ("b", b)): |
3250 | | - if axes > operand.ndim: |
3251 | | - raise ValueError( |
3252 | | - f"axes can not be larger than the dimension of {operand_name} " |
3253 | | - f"({operand_name}.ndim={operand.ndim}, axes={axes})" |
3254 | | - ) |
3255 | | - if batched and axes == operand.ndim: |
3256 | | - raise ValueError( |
3257 | | - "axes to sum over must not include the batch axis " |
3258 | | - f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})" |
3259 | | - ) |
3260 | | - |
3261 | | - batch_axes = 1 if batched else 0 |
3262 | | - a_outaxes = slice(0, a.ndim - axes) |
3263 | | - b_outaxes = slice(batch_axes + axes, b.ndim) |
3264 | | - outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]]) |
3265 | | - outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes] |
3266 | | - outndim = len(outbcast) |
3267 | | - |
3268 | | - a_shape = [1] * 2 |
3269 | | - b_shape = [1] * 2 |
3270 | | - |
3271 | | - # compute total size of summed axes |
3272 | | - for i in range(0, axes): |
3273 | | - a_shape[1] *= a.shape[-(i + 1)] |
3274 | | - b_shape[0] *= b.shape[batch_axes + i] |
3275 | | - # compute total size of other axes |
3276 | | - for i in range(0, a.ndim - axes - batch_axes): |
3277 | | - a_shape[0] *= a.shape[batch_axes + i] |
3278 | | - for i in range(0, b.ndim - axes - batch_axes): |
3279 | | - b_shape[1] *= b.shape[-(i + 1)] |
3280 | | - |
3281 | | - if batched: |
3282 | | - a_shape.insert(0, a.shape[0]) |
3283 | | - b_shape.insert(0, b.shape[0]) |
3284 | | - |
3285 | | - a_reshaped = a.reshape(a_shape) |
3286 | | - b_reshaped = b.reshape(b_shape) |
3287 | | - |
3288 | | - out_reshaped = dot(a_reshaped, b_reshaped) |
3289 | | - out = out_reshaped.reshape(outshape, ndim=outndim) |
3290 | | - # Make sure the broadcastable pattern of the result is correct, |
3291 | | - # since some shape information can be lost in the reshapes. |
3292 | | - if out.type.broadcastable != outbcast: |
3293 | | - out = specify_broadcastable( |
3294 | | - out, *(ax for (ax, b) in enumerate(outbcast) if b) |
3295 | | - ) |
3296 | | - return out |
3297 | | - |
3298 | | - # if 'axes' is a list, transpose a and b such that the summed axes of a |
3299 | | - # are last and the summed axes of b are first. |
3300 | | - else: |
3301 | | - axes = [as_list(axes_) for axes_ in axes] |
3302 | | - |
3303 | | - if len(axes[0]) != len(axes[1]): |
3304 | | - raise ValueError("Axes elements must have the same length.") |
3305 | | - |
3306 | | - for i, (operand_name, operand) in enumerate((("a", a), ("b", b))): |
3307 | | - if len(axes[i]) > operand.ndim: |
3308 | | - raise ValueError( |
3309 | | - f"axes[{i}] should be array_like with length less than " |
3310 | | - f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})." |
3311 | | - ) |
3312 | | - if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim: |
3313 | | - raise ValueError( |
3314 | | - f"axes[{i}] contains dimensions greater than or equal " |
3315 | | - f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})." |
3316 | | - ) |
3317 | | - if batched and 0 in axes[i]: |
3318 | | - raise ValueError( |
3319 | | - "axes to sum over must not contain the batch axis " |
3320 | | - f"(axes[{i}]={axes[i]})" |
3321 | | - ) |
3322 | | - |
3323 | | - batch_axes = [0] if batched else [] |
3324 | | - other_axes = [ |
3325 | | - [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes] |
3326 | | - for i, operand in enumerate((a, b)) |
3327 | | - ] |
3328 | | - |
3329 | | - a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0]) |
3330 | | - b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1]) |
3331 | | - |
3332 | | - # now that a and b are in the right order, recur with integer axes |
3333 | | - return _tensordot_as_dot( |
3334 | | - a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched |
3335 | | - ) |
3336 | | - |
3337 | | - |
3338 | 3211 | def tensordot( |
3339 | 3212 | a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2 |
3340 | 3213 | ) -> TensorVariable: |
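
For reference, the reduction implemented by the removed `_tensordot_as_dot` helper (flatten the contracted dimensions of each operand, perform a single matrix `dot`, and reshape back, transposing first when `axes` is given as a pair of lists) can be sketched in plain NumPy. This is an illustrative sketch of the technique only, not PyTensor's code; `tensordot_via_dot` is a hypothetical name introduced here.

```python
import numpy as np


def tensordot_via_dot(a, b, axes):
    """Reduce a tensor dot product to one matrix dot (illustrative only)."""
    a, b = np.asarray(a), np.asarray(b)
    if not np.isscalar(axes):
        a_axes, b_axes = axes
        # Move the summed axes of `a` to the end and the summed axes of `b`
        # to the front, then fall back to the integer-axes case below.
        a = np.moveaxis(a, a_axes, list(range(a.ndim - len(a_axes), a.ndim)))
        b = np.moveaxis(b, b_axes, list(range(len(b_axes))))
        return tensordot_via_dot(a, b, len(a_axes))
    axes = int(axes)
    out_shape = a.shape[: a.ndim - axes] + b.shape[axes:]
    summed = int(np.prod(a.shape[a.ndim - axes :]))  # total size of contracted dims
    a2 = a.reshape(-1, summed)  # (prod of kept dims of a, summed)
    b2 = b.reshape(summed, -1)  # (summed, prod of kept dims of b)
    return (a2 @ b2).reshape(out_shape)


# Agrees with np.tensordot for both ways of passing `axes`:
x, y = np.random.rand(2, 3, 4), np.random.rand(4, 3, 5)
np.testing.assert_allclose(
    tensordot_via_dot(x, y, [[1, 2], [1, 0]]),
    np.tensordot(x, y, [[1, 2], [1, 0]]),
)
```

The removed helper did the same thing symbolically (via `reshape`, `dimshuffle`, and `specify_broadcastable`) and additionally supported a batched variant that preserved a leading batch axis.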