1515from pytensor .tensor .elemwise import DimShuffle
1616from pytensor .tensor .rewriting .basic import register_specialize
1717from pytensor .tensor .rewriting .blockwise import blockwise_of
18- from pytensor .tensor .rewriting .linalg import is_matrix_transpose
1918from pytensor .tensor .slinalg import Solve , cho_solve , cholesky , lu_factor , lu_solve
2019from pytensor .tensor .variable import TensorVariable
2120
@@ -74,28 +73,26 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
7473 # the root variable is the pre-DimShuffled input.
7574 # Otherwise, `a` is considered the root variable.
7675 # We also return whether the root `a` is transposed.
76+ root_a = a
7777 transposed = False
78- if a .owner is not None and isinstance ( a . owner . op , DimShuffle ) :
79- if a . owner . op . is_left_expand_dims :
80- [ a ] = a . owner . inputs
81- elif is_matrix_transpose ( a ):
82- [ a ] = a . owner . inputs
83- transposed = True
84- return a , transposed
78+ match a .owner_op_and_inputs :
79+ case ( DimShuffle ( is_left_expand_dims = True ), root_a ): # type: ignore[misc]
80+ transposed = False
81+ case ( DimShuffle ( is_left_expanded_matrix_transpose = True ), root_a ): # type: ignore[misc]
82+ transposed = True # type: ignore[unreachable]
83+
84+ return root_a , transposed
8585
8686 def find_solve_clients (var , assume_a ):
8787 clients = []
8888 for cl , idx in fgraph .clients [var ]:
89- if (
90- idx == 0
91- and isinstance (cl .op , Blockwise )
92- and isinstance (cl .op .core_op , Solve )
93- and (cl .op .core_op .assume_a == assume_a )
94- ):
95- clients .append (cl )
96- elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
97- # If it's a left expand_dims, recurse on the output
98- clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
89+ match (idx , cl .op , * cl .outputs ):
90+ case (0 , Blockwise (Solve (assume_a = assume_a_var )), * _) if (
91+ assume_a_var == assume_a
92+ ):
93+ clients .append (cl )
94+ case (0 , DimShuffle (is_left_expand_dims = True ), cl_out ):
95+ clients .extend (find_solve_clients (cl_out , assume_a ))
9996 return clients
10097
10198 assume_a = node .op .core_op .assume_a
@@ -114,11 +111,11 @@ def find_solve_clients(var, assume_a):
114111
115112 # Find Solves using A.T
116113 for cl , _ in fgraph .clients [A ]:
117- if isinstance (cl .op , DimShuffle ) and is_matrix_transpose ( cl .out ):
118- A_T = cl . out
119- A_solve_clients_and_transpose .extend (
120- (client , True ) for client in find_solve_clients (A_T , assume_a )
121- )
114+ match (cl .op , * cl .outputs ):
115+ case ( DimShuffle ( is_left_expanded_matrix_transpose = True ), A_T ):
116+ A_solve_clients_and_transpose .extend (
117+ (client , True ) for client in find_solve_clients (A_T , assume_a )
118+ )
122119
123120 if not eager and len (A_solve_clients_and_transpose ) == 1 :
124121 # If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
@@ -171,34 +168,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
171168 changed = False
172169 while True :
173170 for inner_node in new_scan_fgraph .toposort ():
174- if (
175- isinstance (inner_node .op , Blockwise )
176- and isinstance (inner_node .op .core_op , Solve )
177- and inner_node .op .core_op .assume_a in allowed_assume_a
178- ):
179- A , _b = inner_node .inputs
180- if all (
181- (isinstance (root_inp , Constant ) or (root_inp in non_sequences ))
182- for root_inp in graph_inputs ([A ])
171+ match (inner_node .op , * inner_node .inputs ):
172+ case (Blockwise (Solve (assume_a = assume_a_var )), A , _b ) if (
173+ assume_a_var in allowed_assume_a
183174 ):
184- if new_scan_fgraph is scan_op .fgraph :
185- # Clone the first time to avoid mutating the original fgraph
186- new_scan_fgraph , equiv = new_scan_fgraph .clone_get_equiv ()
187- non_sequences = {equiv [non_seq ] for non_seq in non_sequences }
188- inner_node = equiv [inner_node ] # type: ignore
189-
190- replace_dict = _split_decomp_and_solve_steps (
191- new_scan_fgraph ,
192- inner_node ,
193- eager = True ,
194- allowed_assume_a = allowed_assume_a ,
195- )
196- assert isinstance (replace_dict , dict ) and len (replace_dict ) > 0 , (
197- "Rewrite failed"
198- )
199- new_scan_fgraph .replace_all (replace_dict .items ())
200- changed = True
201- break # Break to start over with a fresh toposort
175+ if all (
176+ (isinstance (root_inp , Constant ) or (root_inp in non_sequences ))
177+ for root_inp in graph_inputs ([A ])
178+ ):
179+ if new_scan_fgraph is scan_op .fgraph :
180+ # Clone the first time to avoid mutating the original fgraph
181+ new_scan_fgraph , equiv = new_scan_fgraph .clone_get_equiv ()
182+ non_sequences = {
183+ equiv [non_seq ] for non_seq in non_sequences
184+ }
185+ inner_node = equiv [inner_node ] # type: ignore
186+
187+ replace_dict = _split_decomp_and_solve_steps (
188+ new_scan_fgraph ,
189+ inner_node ,
190+ eager = True ,
191+ allowed_assume_a = allowed_assume_a ,
192+ )
193+ assert (
194+ isinstance (replace_dict , dict ) and len (replace_dict ) > 0
195+ ), "Rewrite failed"
196+ new_scan_fgraph .replace_all (replace_dict .items ())
197+ changed = True
198+ break # Break to start over with a fresh toposort
202199 else : # no_break
203200 break # Nothing else changed
204201
0 commit comments