2121from pytensor .tensor .blockwise import Blockwise
2222from pytensor .tensor .nlinalg import kron , matrix_dot
2323from pytensor .tensor .shape import reshape
24- from pytensor .tensor .type import ivector , matrix , tensor , vector
24+ from pytensor .tensor .type import matrix , tensor , vector
2525from pytensor .tensor .variable import TensorVariable
2626
2727
@@ -585,10 +585,14 @@ def lu(
585585
586586
587587class PivotToPermutations (Op ):
588- itypes = [ivector ]
589- otypes = [ivector ]
588+ __props__ = ("inverse" , "inplace" )
590589
591- __props__ = ()
590+ def __init__ (self , inverse = True , inplace = False ):
591+ self .inverse = inverse
592+ self .inplace = inplace
593+ self .destroy_map = {}
594+ if self .inplace :
595+ self .destroy_map = {0 : [0 ]}
592596
593597 def make_node (self , pivots ):
594598 pivots = as_tensor_variable (pivots )
@@ -598,15 +602,29 @@ def make_node(self, pivots):
598602
599603 return Apply (self , [pivots ], [permutations ])
600604
605+ def inplace_on_inputs (self , allowed_inplace_inputs : list [int ]) -> "Op" :
606+ if 0 in allowed_inplace_inputs :
607+ new_props = self ._props_dict () # type: ignore
608+ new_props ["inplace" ] = True
609+ return type (self )(** new_props )
610+ else :
611+ return self
612+
601613 def perform (self , node , inputs , outputs ):
602614 [p ] = inputs
603615 p_inv = np .arange (len (p )).astype (p .dtype )
604616 for i in range (len (p )):
605617 p_inv [i ], p_inv [p [i ]] = p_inv [p [i ]], p_inv [i ]
606- outputs [0 ][0 ] = p_inv
618+
619+ if self .inverse :
620+ outputs [0 ][0 ] = p_inv
621+
622+ outputs [0 ][0 ] = np .argsort (p_inv )
607623
608624
609- _pivot_to_permutation = PivotToPermutations ()
625+ def pivot_to_permutation (p : TensorLike , inverse = False ) -> Variable :
626+ p = pt .as_tensor_variable (p )
627+ return PivotToPermutations (inverse = inverse )(p )
610628
611629
612630class LUFactor (Op ):
@@ -631,7 +649,7 @@ def make_node(self, A):
631649 )
632650
633651 LU = matrix (shape = A .type .shape , dtype = A .type .dtype )
634- pivots_or_permutations = vector (shape = (A .type .shape [0 ],), dtype = "int32 " )
652+ pivots_or_permutations = vector (shape = (A .type .shape [0 ],), dtype = "int64 " )
635653
636654 return Apply (self , [A ], [LU , pivots_or_permutations ])
637655
@@ -681,8 +699,7 @@ def L_op(self, inputs, outputs, output_gradients):
681699 U = cast (TensorVariable , ptb .triu (LU ))
682700
683701 if not self .permutation_indices :
684- p_indices_inv = _pivot_to_permutation (cast (TensorVariable , p_indices ))
685- p_indices = pt .argsort (p_indices_inv )
702+ p_indices = pivot_to_permutation (p_indices , inverse = False )
686703
687704 # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
688705 L_bar = ptb .tril (LU_bar , k = - 1 )
@@ -784,7 +801,7 @@ def lu_solve(
784801 LU , pivots = LU_and_pivots
785802
786803 LU , pivots , b = map (pt .as_tensor_variable , [LU , pivots , b ])
787- inv_permutation = _pivot_to_permutation (pivots )
804+ inv_permutation = pivot_to_permutation (pivots , inverse = True )
788805
789806 x = b [inv_permutation ] if not trans else b
790807
0 commit comments