@@ -410,6 +410,12 @@ def _contraction_list_from_path(
410410 return contraction_list
411411
412412
413+ def _right_to_left_path (n : int ) -> tuple [tuple [int , int ], ...]:
414+ # Create a right to left contraction path
415+ # if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
416+ return tuple (pairwise (reversed (range (n ))))
417+
418+
413419def einsum (subscripts : str , * operands : "TensorLike" , optimize = None ) -> TensorVariable :
414420 """
415421 Multiplication and summation of tensors using the Einstein summation convention.
@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
563569 else :
564570 # By default, we try right to left because we assume that most graphs
565571 # have a lower dimensional rightmost operand
566- path = tuple ( pairwise ( reversed ( range ( len (tensor_operands ))) ))
572+ path = _right_to_left_path ( len (tensor_operands ))
567573 contraction_list = _contraction_list_from_path (
568574 subscripts , tensor_operands , path
569575 )
@@ -582,6 +588,15 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
582588 optimize = "optimal" ,
583589 ) # type: ignore
584590 path = tuple (contraction [0 ] for contraction in contraction_list )
591+
592+ if len (path ) == 1 and len (path [0 ]) > 2 :
593+ # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
594+ # pairwise reductions, which our implementation below demands.
595+ path = _right_to_left_path (len (tensor_operands ))
596+ contraction_list = _contraction_list_from_path (
597+ subscripts , tensor_operands , path
598+ )
599+
585600 optimized = True
586601
587602 def removechars (s , chars ):
@@ -744,7 +759,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
744759 )
745760 else :
746761 raise ValueError (
747- f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} "
762+ f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} , { path = } . "
748763 )
749764
750765 # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
0 commit comments