@@ -410,6 +410,12 @@ def _contraction_list_from_path(
410410 return contraction_list
411411
412412
413+ def _right_to_left_path (n : int ) -> tuple [tuple [int , int ], ...]:
414+ # Create a right to left contraction path
415+ # if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
416+ return tuple (pairwise (reversed (range (n ))))
417+
418+
413419def einsum (subscripts : str , * operands : "TensorLike" , optimize = None ) -> TensorVariable :
414420 """
415421 Multiplication and summation of tensors using the Einstein summation convention.
@@ -546,8 +552,6 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
546552 "If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
547553 )
548554
549- # TODO: Is this doing something clever about unknown shapes?
550- # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
551555 tensor_operands = [as_tensor (operand ) for operand in operands ]
552556 shapes = [operand .type .shape for operand in tensor_operands ]
553557
@@ -565,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
565569 else :
566570 # By default, we try right to left because we assume that most graphs
567571 # have a lower dimensional rightmost operand
568- path = tuple ( pairwise ( reversed ( range ( len (tensor_operands ))) ))
572+ path = _right_to_left_path ( len (tensor_operands ))
569573 contraction_list = _contraction_list_from_path (
570574 subscripts , tensor_operands , path
571575 )
@@ -584,6 +588,15 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
584588 optimize = "optimal" ,
585589 ) # type: ignore
586590 path = tuple (contraction [0 ] for contraction in contraction_list )
591+
592+ if len (path ) == 1 and len (path [0 ]) > 2 :
593+ # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
594+ # pairwise reductions, which our implementation below demands.
595+ path = _right_to_left_path (len (tensor_operands ))
596+ contraction_list = _contraction_list_from_path (
597+ subscripts , tensor_operands , path
598+ )
599+
587600 optimized = True
588601
589602 def removechars (s , chars ):
@@ -745,6 +758,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
745758 batch_axes = (lhs_batch , rhs_batch ),
746759 )
747760 else :
761+ print (operand_indices , contracted_names , einstr )
748762 raise ValueError (
749763 f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} "
750764 )
0 commit comments