@@ -100,43 +100,40 @@ def find_solve_clients(var, assume_a):
100100 elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
101101 # If it's a left expand_dims, recurse on the output
102102 clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
103-
104103 return clients
105104
106105 assume_a = node .op .core_op .assume_a
107106
108107 if assume_a not in allowed_assume_a :
109108 return None
110109
111- root_A , root_A_transposed = get_root_A (node .inputs [0 ])
110+ A , _ = get_root_A (node .inputs [0 ])
112111
113112 # Find Solve using A (or left expand_dims of A)
114113 # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
115114 # that to the A_decomp outputs
116- root_A_solve_clients_and_transpose = [
117- (client , False ) for client in find_solve_clients (root_A , assume_a )
115+ A_solve_clients_and_transpose = [
116+ (client , False ) for client in find_solve_clients (A , assume_a )
118117 ]
119118
120119 # Find Solves using A.T
121- for cl , _ in fgraph .clients [root_A ]:
120+ for cl , _ in fgraph .clients [A ]:
122121 if isinstance (cl .op , DimShuffle ) and is_matrix_transpose (cl .out ):
123122 A_T = cl .out
124- root_A_solve_clients_and_transpose .extend (
123+ A_solve_clients_and_transpose .extend (
125124 (client , True ) for client in find_solve_clients (A_T , assume_a )
126125 )
127126
128- if not eager and len (root_A_solve_clients_and_transpose ) == 1 :
127+ if not eager and len (A_solve_clients_and_transpose ) == 1 :
129128 # If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
130129 # That's a "reuse" inside the inner vectorized loop
131130 batch_ndim = node .op .batch_ndim (node )
132- (client , _ ) = root_A_solve_clients_and_transpose [0 ]
133-
134- A , b = client .inputs
135-
131+ (client , _ ) = A_solve_clients_and_transpose [0 ]
132+ original_A , b = client .inputs
136133 if not any (
137134 a_bcast and not b_bcast
138135 for a_bcast , b_bcast in zip (
139- A .type .broadcastable [:batch_ndim ],
136+ original_A .type .broadcastable [:batch_ndim ],
140137 b .type .broadcastable [:batch_ndim ],
141138 strict = True ,
142139 )
@@ -145,27 +142,19 @@ def find_solve_clients(var, assume_a):
145142
146143 # If any Op had check_finite=True, we also do it for the LU decomposition
147144 check_finite_decomp = False
148- for client , _ in root_A_solve_clients_and_transpose :
145+ for client , _ in A_solve_clients_and_transpose :
149146 if client .op .core_op .check_finite :
150147 check_finite_decomp = True
151148 break
152149
153- (first_solve , transposed ) = root_A_solve_clients_and_transpose [0 ]
154- lower = first_solve .op .core_op .lower
155- if transposed :
156- lower = not lower
157-
150+ lower = node .op .core_op .lower
158151 A_decomp = decompose_A (
159- root_A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
152+ A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
160153 )
161154
162155 replacements = {}
163- for client , transposed in root_A_solve_clients_and_transpose :
156+ for client , transposed in A_solve_clients_and_transpose :
164157 _ , b = client .inputs
165- lower = client .op .core_op .lower
166- if transposed :
167- lower = not lower
168-
169158 new_x = solve_decomposed_system (
170159 A_decomp ,
171160 b ,
0 commit comments