@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input

 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike
@@ -33,14 +36,13 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """

-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")
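+    # __props__ participate in Op equality and hashing, so path and optimized must be fixed, hashable values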

-    def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
-    ):
+    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)


 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
@@ -141,6 +143,57 @@ def _general_dot(
     return cast(TensorVariable, out)


+PATH = tuple[tuple[int] | tuple[int, int], ...]
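+# Each path step is a tuple of operand positions to contract, in the same format as np.einsum_path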
+
+
+def contraction_list_from_path(
+    subscripts: str, operands: Sequence[TensorLike], path: PATH
+):
+ """TODO Docstrings
153
+
154
+ Code adapted from einsum_opt
155
+ """
+    fake_operands = [
+        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
+    ]
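+    # The dummy arrays carry each input's ndim and which dims are 1; parsing only needs shapes, not data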
+    input_subscripts, output_subscript, operands = parse_einsum_input(
+        (subscripts, *fake_operands)
+    )
+
+    # Build a few useful lists and sets
+    input_list = input_subscripts.split(",")
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuple (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(contract_inds, reverse=True))
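+        # e.g. (0, 2) becomes (2, 0), so popping index 2 first leaves index 0 unshifted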
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
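+            # e.g. tmp_inputs ["ab", "bc"] with out_inds {"a", "c"} give "ac", matching tensordot's output layout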
+
+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three inputs to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
@@ -167,18 +220,35 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path

-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]

-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False, even in cases where there is no ordering optimization to be done,
+        # because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create default path of repeating (1, 0) that executes left to right cyclically,
+            # with intermediate outputs being pushed to the end of the stack.
+            # We use (1, 0) and not (0, 1) because that's what opt_einsum tends to prefer,
+            # and so the Op signatures will match more often
+            path = [(1, 0) for i in range(len(operands) - 1)]
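+            # e.g. with 3 operands this gives [(1, 0), (1, 0)]: contract the first two, then the third with that result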
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        optimized = (
+            len(operands) <= 2
+        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
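+        # shapes=True lets contract_path plan from the static shapes alone, without real arrays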
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True

     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
@@ -245,6 +315,7 @@ def sum_repeats(
             lhs, rhs = map(einsum_operands.pop, operand_indices)
             lhs_names, rhs_names = input_names

+            # TODO: Do this as well?
             # handle cases where one side of a contracting or batch dimension is 1
             # but its counterpart is not.
             # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -322,6 +393,10 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+            )

         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
         assert len(names) == len(result_names) == len(set(names))
@@ -337,5 +412,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)