|
11 | 11 | from pytensor.npy_2_compat import ( |
12 | 12 | _find_contraction, |
13 | 13 | _parse_einsum_input, |
14 | | - normalize_axis_index, |
15 | 14 | normalize_axis_tuple, |
16 | 15 | ) |
17 | 16 | from pytensor.tensor import TensorLike |
18 | 17 | from pytensor.tensor.basic import ( |
19 | | - arange, |
20 | 18 | as_tensor, |
21 | 19 | expand_dims, |
22 | | - get_vector_length, |
| 20 | + iota, |
23 | 21 | moveaxis, |
24 | 22 | stack, |
25 | 23 | transpose, |
|
28 | 26 | from pytensor.tensor.extra_ops import broadcast_to |
29 | 27 | from pytensor.tensor.functional import vectorize |
30 | 28 | from pytensor.tensor.math import and_, eq, tensordot |
31 | | -from pytensor.tensor.shape import shape_padright |
32 | 29 | from pytensor.tensor.variable import TensorVariable |
33 | 30 |
|
34 | 31 |
|
@@ -63,64 +60,6 @@ def __str__(self): |
63 | 60 | return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}" |
64 | 61 |
|
65 | 62 |
|
66 | | -def _iota(shape: TensorVariable, axis: int) -> TensorVariable: |
67 | | - """ |
68 | | - Create an array with values increasing along the specified axis. |
69 | | -
|
70 | | - Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers |
71 | | - increasing along the specified axis. |
72 | | -
|
73 | | - Parameters |
74 | | - ---------- |
75 | | - shape: TensorVariable |
76 | | - The shape of the array to be created. |
77 | | - axis: int |
78 | | - The axis along which to fill the array with increasing values. |
79 | | -
|
80 | | - Returns |
81 | | - ------- |
82 | | - TensorVariable |
83 | | - An array with values increasing along the specified axis. |
84 | | -
|
85 | | - Examples |
86 | | - -------- |
87 | | - In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``: |
88 | | -
|
89 | | - .. testcode:: |
90 | | -
|
91 | | - import pytensor.tensor as pt |
92 | | - from pytensor.tensor.einsum import _iota |
93 | | -
|
94 | | - shape = pt.as_tensor((5,)) |
95 | | - print(_iota(shape, 0).eval()) |
96 | | -
|
97 | | - .. testoutput:: |
98 | | -
|
99 | | - [0 1 2 3 4] |
100 | | -
|
101 | | - In higher dimensions, it will look like many concatenated `arange`: |
102 | | -
|
103 | | - .. testcode:: |
104 | | -
|
105 | | - shape = pt.as_tensor((5, 5)) |
106 | | - print(_iota(shape, 1).eval()) |
107 | | -
|
108 | | - .. testoutput:: |
109 | | -
|
110 | | - [[0 1 2 3 4] |
111 | | - [0 1 2 3 4] |
112 | | - [0 1 2 3 4] |
113 | | - [0 1 2 3 4] |
114 | | - [0 1 2 3 4]] |
115 | | -
|
116 | | - Setting ``axis=0`` above would result in the transpose of the output. |
117 | | - """ |
118 | | - len_shape = get_vector_length(shape) |
119 | | - axis = normalize_axis_index(axis, len_shape) |
120 | | - values = arange(shape[axis]) |
121 | | - return broadcast_to(shape_padright(values, len_shape - axis - 1), shape) |
122 | | - |
123 | | - |
124 | 63 | def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: |
125 | 64 | """ |
126 | 65 | Create a Kroncker delta tensor. |
@@ -201,7 +140,7 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable: |
201 | 140 | if len(axes) == 1: |
202 | 141 | raise ValueError("Need at least two axes to create a delta tensor") |
203 | 142 | base_shape = stack([shape[axis] for axis in axes]) |
204 | | - iotas = [_iota(base_shape, i) for i in range(len(axes))] |
| 143 | + iotas = [iota(base_shape, i) for i in range(len(axes))] |
205 | 144 | eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)] |
206 | 145 | result = reduce(and_, eyes) |
207 | 146 | non_axes = [i for i in range(len(tuple(shape))) if i not in axes] |
|
0 commit comments