
Commit 1f371b7

Allow einsum to work with inputs of unknown static shape
1 parent 8250e32 commit 1f371b7
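This allows pt.einsum to build a graph even when some operand dimensions are only known at runtime. A minimal sketch of the newly supported case (variable names are illustrative; the optimized attribute is introduced by this commit):

import pytensor.tensor as pt

# Static shapes are unknown: previously contract_path required concrete
# dimensions, so this graph could not be constructed.
x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
z = pt.tensor("z", shape=(None, None))

out = pt.einsum("ij,jk,kl->il", x, y, z)
# With more than two operands and unknown shapes, a default contraction
# order is used and the Op records that no path optimization happened.
assert not out.owner.op.optimized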

File tree: 2 files changed (+104 −20 lines)

pytensor/tensor/einsum.py

Lines changed: 93 additions & 16 deletions
@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input
 
 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike
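Both new imports are opt_einsum internals: parse_einsum_input normalizes a subscripts string and makes the implicit output explicit, while find_contraction computes which indices survive a given pairwise contraction. A rough sketch of the former, based on opt_einsum's documented behaviour rather than on this diff:

import numpy as np
from opt_einsum.parser import parse_einsum_input

a, b = np.ones((2, 3)), np.ones((3, 4))
input_subs, output_sub, ops = parse_einsum_input(("ij,jk", a, b))
print(input_subs, "->", output_sub)  # ij,jk -> ik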
@@ -33,14 +36,13 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """
 
-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")
 
-    def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
-    ):
+    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)
 
 
 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
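With optimize replaced by path and optimized in __props__, the chosen contraction order becomes part of the Op's identity and can be inspected on the built graph. A hedged illustration (the exact path value is indicative, not guaranteed):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(3, 7))

out = pt.einsum("ij,jk->ik", x, y)
op = out.owner.op
print(op.path)       # a tuple of pairwise contraction steps, e.g. ((1, 0),)
print(op.optimized)  # True: static shapes were known, so contract_path chose the path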
@@ -141,6 +143,57 @@ def _general_dot(
     return cast(TensorVariable, out)
 
 
+PATH = tuple[tuple[int] | tuple[int, int]]
+
+
+def contraction_list_from_path(
+    subscripts: str, operands: Sequence[TensorLike], path: PATH
+):
+    """TODO Docstrings
+
+    Code adapted from einsum_opt
+    """
+    fake_operands = [
+        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
+    ]
+    input_subscripts, output_subscript, operands = parse_einsum_input(
+        (subscripts, *fake_operands)
+    )
+
+    # Build a few useful list and sets
+    input_list = input_subscripts.split(",")
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuple (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(contract_inds, reverse=True))
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
+
+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three inputs to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
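contraction_list_from_path is an internal helper, but a sketch of calling it with the default left-to-right path may make the shape of its output clearer (assuming the function stays importable from pytensor.tensor.einsum; the index ordering inside each intermediate einsum string is chosen to minimize transpositions, so the comments are schematic):

import pytensor.tensor as pt
from pytensor.tensor.einsum import contraction_list_from_path

A = pt.tensor("A", shape=(None, None))
B = pt.tensor("B", shape=(None, None))
C = pt.tensor("C", shape=(None, None))

contraction_list = contraction_list_from_path(
    "ij,jk,kl->il", [A, B, C], [(1, 0), (1, 0)]
)
# Step 0: pop operands 1 and 0 (B and A), sum away "j", and push the
#         intermediate result to the end of the operand stack.
# Step 1: pop the intermediate and C, sum away "k", producing "il".
# Each entry is (contract_inds, idx_removed, einsum_str, None, None); the two
# trailing None slots stand in for opt_einsum fields the forward graph ignores.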
@@ -167,18 +220,35 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path
 
-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]
 
-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False, even in cases where there is no ordering optimization to be done
+        # because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create default path of repeating (1,0) that executes left to right cyclically
+            # with intermediate outputs being pushed to the end of the stack
+            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
+            path = [(1, 0) for i in range(len(operands) - 1)]
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        optimized = (
+            len(operands) <= 2
+        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True
 
     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
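The two branches are easiest to compare side by side: fully known static shapes still go through opt_einsum's optimizer, while a single unknown dimension anywhere selects the fixed cyclic (1, 0) path. A hedged sketch (shape values are arbitrary):

import pytensor.tensor as pt

sig = "ij,jk,kl->il"

# Known static shapes: contract_path(optimize="optimal") picks the ordering.
known = [pt.tensor(n, shape=s) for n, s in zip("abc", [(2, 300), (300, 5), (5, 7)])]
assert pt.einsum(sig, *known).owner.op.optimized

# One unknown dimension: fall back to the default left-to-right path.
unknown = [pt.tensor(n, shape=s) for n, s in zip("abc", [(2, None), (300, 5), (5, 7)])]
out = pt.einsum(sig, *unknown)
assert out.owner.op.path == ((1, 0), (1, 0))
assert not out.owner.op.optimized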
@@ -245,6 +315,7 @@ def sum_repeats(
             lhs, rhs = map(einsum_operands.pop, operand_indices)
             lhs_names, rhs_names = input_names
 
+            # TODO: Do this as well?
             # handle cases where one side of a contracting or batch dimension is 1
             # but its counterpart is not.
             # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -322,6 +393,10 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+            )
 
         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
         assert len(names) == len(result_names) == len(set(names))
@@ -337,5 +412,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)

tests/tensor/test_einsum.py

Lines changed: 11 additions & 4 deletions
@@ -72,6 +72,7 @@ def test_general_dot():
     )
 
 
+@pytest.mark.parametrize("static_shape_known", [True, False])
 @pytest.mark.parametrize(
     "signature",
     [
@@ -95,16 +96,23 @@ def test_general_dot():
         "oij,imj,mjkn,lnk,plk->op",
     ],
 )
-def test_parse_einsum_input(signature):
+def test_einsum_signatures(static_shape_known, signature):
     letters_to_dims = dict(zip("ijklmnop", [2, 3, 5, 7, 11, 13, 17, 19], strict=True))
 
     inputs = signature.split("->")[0].split(",")
 
     shapes = [tuple(letters_to_dims[letter] for letter in inp) for inp in inputs]
+    if static_shape_known:
+        static_shapes = shapes
+    else:
+        static_shapes = [[None] * len(shape) for shape in shapes]
+
     operands = [
-        pt.tensor(name, shape=shape) for name, shape in zip(ascii_lowercase, shapes)
+        pt.tensor(name, shape=static_shape)
+        for name, static_shape in zip(ascii_lowercase, static_shapes)
     ]
     out = pt.einsum(signature, *operands)
+    assert out.owner.op.optimized == static_shape_known or len(operands) <= 2
 
     rng = np.random.default_rng(37)
     test_values = [rng.normal(size=shape) for shape in shapes]
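The new assertion parses as (out.owner.op.optimized == static_shape_known) or (len(operands) <= 2): with one or two operands there is no contraction order to choose, so the Op can legitimately report optimized=True even when the static shapes are unknown. A small sketch of that corner case (not part of the test itself):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, None))
y = pt.tensor("y", shape=(None, None))
out = pt.einsum("ij,jk->ik", x, y)
# Two operands: only one possible contraction order, so even with unknown
# shapes the Op reports optimized=True and the assertion above still passes.
assert out.owner.op.optimized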
@@ -113,9 +121,8 @@ def test_parse_einsum_input(signature):
     fn = function(operands, out)
     pt_out = fn(*test_values)
 
-    # print()
     # import pytensor
-    # pytensor.dprint(fn, print_type=True)
+    # print(); pytensor.dprint(fn, print_type=True)
 
     # assert out.type.shape == np_out.shape  # Reshape operations lose static shape
     np.testing.assert_allclose(pt_out, np_out)
