Skip to content

Commit 21a158d

Browse files
committed
Move iota from einsum.py to basic.py
1 parent a6444c7 commit 21a158d

File tree

2 files changed

+65
-65
lines changed

2 files changed

+65
-65
lines changed

pytensor/tensor/basic.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
get_vector_length,
4444
)
4545
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
46-
from pytensor.tensor.einsum import _iota
4746
from pytensor.tensor.elemwise import (
4847
DimShuffle,
4948
Elemwise,
@@ -1061,6 +1060,65 @@ def flatnonzero(a):
10611060
return nonzero(_a.flatten(), return_matrix=False)[0]
10621061

10631062

1063+
def iota(shape: TensorVariable, axis: int) -> TensorVariable:
1064+
"""
1065+
Create an array with values increasing along the specified axis.
1066+
1067+
Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
1068+
increasing along the specified axis.
1069+
1070+
Parameters
1071+
----------
1072+
shape: TensorVariable
1073+
The shape of the array to be created.
1074+
axis: int
1075+
The axis along which to fill the array with increasing values.
1076+
1077+
Returns
1078+
-------
1079+
TensorVariable
1080+
An array with values increasing along the specified axis.
1081+
1082+
Examples
1083+
--------
1084+
In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:
1085+
1086+
.. testcode::
1087+
1088+
import pytensor.tensor as pt
1089+
1090+
shape = pt.as_tensor((5,))
1091+
print(pt.basic.iota(shape, 0).eval())
1092+
1093+
.. testoutput::
1094+
1095+
[0 1 2 3 4]
1096+
1097+
In higher dimensions, it will look like many concatenated `arange`:
1098+
1099+
.. testcode::
1100+
1101+
shape = pt.as_tensor((5, 5))
1102+
print(pt.basic.iota(shape, 1).eval())
1103+
1104+
.. testoutput::
1105+
1106+
[[0 1 2 3 4]
1107+
[0 1 2 3 4]
1108+
[0 1 2 3 4]
1109+
[0 1 2 3 4]
1110+
[0 1 2 3 4]]
1111+
1112+
Setting ``axis=0`` above would result in the transpose of the output.
1113+
"""
1114+
len_shape = get_vector_length(shape)
1115+
axis = normalize_axis_index(axis, len_shape)
1116+
values = arange(shape[axis])
1117+
return pytensor.tensor.extra_ops.broadcast_to(
1118+
shape_padright(values, len_shape - axis - 1), shape
1119+
)
1120+
1121+
10641122
def nonzero_values(a):
10651123
"""Return a vector of non-zero elements contained in the input array.
10661124
@@ -1128,7 +1186,10 @@ def tri(N, M=None, k=0, dtype=None):
11281186
dtype = config.floatX
11291187
if M is None:
11301188
M = N
1131-
output = ((_iota(M) + k) > _iota(N)).astype(int)
1189+
output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype(
1190+
int
1191+
)
1192+
N = as_tensor_variable(N)
11321193
return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N)
11331194

11341195

pytensor/tensor/einsum.py

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111
from pytensor.npy_2_compat import (
1212
_find_contraction,
1313
_parse_einsum_input,
14-
normalize_axis_index,
1514
normalize_axis_tuple,
1615
)
1716
from pytensor.tensor import TensorLike
1817
from pytensor.tensor.basic import (
19-
arange,
2018
as_tensor,
2119
expand_dims,
22-
get_vector_length,
20+
iota,
2321
moveaxis,
2422
stack,
2523
transpose,
@@ -28,7 +26,6 @@
2826
from pytensor.tensor.extra_ops import broadcast_to
2927
from pytensor.tensor.functional import vectorize
3028
from pytensor.tensor.math import and_, eq, tensordot
31-
from pytensor.tensor.shape import shape_padright
3229
from pytensor.tensor.variable import TensorVariable
3330

3431

@@ -63,64 +60,6 @@ def __str__(self):
6360
return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}"
6461

6562

66-
def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
67-
"""
68-
Create an array with values increasing along the specified axis.
69-
70-
Iota is a multidimensional generalization of the `arange` function. The returned array is filled with whole numbers
71-
increasing along the specified axis.
72-
73-
Parameters
74-
----------
75-
shape: TensorVariable
76-
The shape of the array to be created.
77-
axis: int
78-
The axis along which to fill the array with increasing values.
79-
80-
Returns
81-
-------
82-
TensorVariable
83-
An array with values increasing along the specified axis.
84-
85-
Examples
86-
--------
87-
In the simplest case where ``shape`` is 1d, the output will be equivalent to ``pt.arange``:
88-
89-
.. testcode::
90-
91-
import pytensor.tensor as pt
92-
from pytensor.tensor.einsum import _iota
93-
94-
shape = pt.as_tensor((5,))
95-
print(_iota(shape, 0).eval())
96-
97-
.. testoutput::
98-
99-
[0 1 2 3 4]
100-
101-
In higher dimensions, it will look like many concatenated `arange`:
102-
103-
.. testcode::
104-
105-
shape = pt.as_tensor((5, 5))
106-
print(_iota(shape, 1).eval())
107-
108-
.. testoutput::
109-
110-
[[0 1 2 3 4]
111-
[0 1 2 3 4]
112-
[0 1 2 3 4]
113-
[0 1 2 3 4]
114-
[0 1 2 3 4]]
115-
116-
Setting ``axis=0`` above would result in the transpose of the output.
117-
"""
118-
len_shape = get_vector_length(shape)
119-
axis = normalize_axis_index(axis, len_shape)
120-
values = arange(shape[axis])
121-
return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)
122-
123-
12463
def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
12564
"""
12665
Create a Kroncker delta tensor.
@@ -201,7 +140,7 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
201140
if len(axes) == 1:
202141
raise ValueError("Need at least two axes to create a delta tensor")
203142
base_shape = stack([shape[axis] for axis in axes])
204-
iotas = [_iota(base_shape, i) for i in range(len(axes))]
143+
iotas = [iota(base_shape, i) for i in range(len(axes))]
205144
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
206145
result = reduce(and_, eyes)
207146
non_axes = [i for i in range(len(tuple(shape))) if i not in axes]

0 commit comments

Comments
 (0)