@@ -585,41 +585,30 @@ def lu(
585585
586586
587587class PivotToPermutations (Op ):
588- __props__ = ("inverse" , "inplace" )
588+ __props__ = ("inverse" ,)
589589
590- def __init__ (self , inverse = True , inplace = False ):
590+ def __init__ (self , inverse = True ):
591591 self .inverse = inverse
592- self .inplace = inplace
593- self .destroy_map = {}
594- if self .inplace :
595- self .destroy_map = {0 : [0 ]}
596592
597593 def make_node (self , pivots ):
598594 pivots = as_tensor_variable (pivots )
599595 if pivots .ndim != 1 :
600596 raise ValueError ("PivotToPermutations only works on 1-D inputs" )
601- permutations = pivots .type ()
602597
598+ permutations = pivots .type .clone (dtype = "int64" )()
603599 return Apply (self , [pivots ], [permutations ])
604600
605- def inplace_on_inputs (self , allowed_inplace_inputs : list [int ]) -> "Op" :
606- if 0 in allowed_inplace_inputs :
607- new_props = self ._props_dict () # type: ignore
608- new_props ["inplace" ] = True
609- return type (self )(** new_props )
610- else :
611- return self
612-
613601 def perform (self , node , inputs , outputs ):
614- [p ] = inputs
615- p_inv = np .arange (len (p )).astype (p .dtype )
616- for i in range (len (p )):
617- p_inv [i ], p_inv [p [i ]] = p_inv [p [i ]], p_inv [i ]
602+ [pivots ] = inputs
603+ p_inv = np .arange (len (pivots ), dtype = pivots .dtype )
604+
605+ for i in range (len (pivots )):
606+ p_inv [i ], p_inv [pivots [i ]] = p_inv [pivots [i ]], p_inv [i ]
618607
619608 if self .inverse :
620609 outputs [0 ][0 ] = p_inv
621-
622- outputs [0 ][0 ] = np .argsort (p_inv )
610+ else :
611+ outputs [0 ][0 ] = np .argsort (p_inv )
623612
624613
625614def pivot_to_permutation (p : TensorLike , inverse = False ) -> Variable :
@@ -629,14 +618,14 @@ def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable:
629618
630619class LUFactor (Op ):
631620 __props__ = ("overwrite_a" , "check_finite" , "permutation_indices" )
621+ gufunc_signature = "(m,m)->(m,m),(m)"
632622
633623 def __init__ (
634624 self , * , overwrite_a = False , check_finite = True , permutation_indices = False
635625 ):
636626 self .overwrite_a = overwrite_a
637627 self .check_finite = check_finite
638628 self .permutation_indices = permutation_indices
639- self .gufunc_signature = "(m,m)->(m,m),(m)"
640629
641630 if self .overwrite_a :
642631 self .destroy_map = {1 : [0 ]}
0 commit comments