Skip to content

Commit e265141

Browse files
committed
Simplify linalg rewrites with pattern matching
1 parent 93f7629 commit e265141

File tree

5 files changed

+396
-627
lines changed

5 files changed

+396
-627
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.rewriting.basic import register_specialize
1717
from pytensor.tensor.rewriting.blockwise import blockwise_of
18-
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
1918
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
2019
from pytensor.tensor.variable import TensorVariable
2120

@@ -74,28 +73,26 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
7473
# the root variable is the pre-DimShuffled input.
7574
# Otherwise, `a` is considered the root variable.
7675
# We also return whether the root `a` is transposed.
76+
root_a = a
7777
transposed = False
78-
if a.owner is not None and isinstance(a.owner.op, DimShuffle):
79-
if a.owner.op.is_left_expand_dims:
80-
[a] = a.owner.inputs
81-
elif is_matrix_transpose(a):
82-
[a] = a.owner.inputs
83-
transposed = True
84-
return a, transposed
78+
match a.owner_op_and_inputs:
79+
case (DimShuffle(is_left_expand_dims=True), root_a): # type: ignore[misc]
80+
transposed = False
81+
case (DimShuffle(is_left_expanded_matrix_transpose=True), root_a): # type: ignore[misc]
82+
transposed = True # type: ignore[unreachable]
83+
84+
return root_a, transposed
8585

8686
def find_solve_clients(var, assume_a):
8787
clients = []
8888
for cl, idx in fgraph.clients[var]:
89-
if (
90-
idx == 0
91-
and isinstance(cl.op, Blockwise)
92-
and isinstance(cl.op.core_op, Solve)
93-
and (cl.op.core_op.assume_a == assume_a)
94-
):
95-
clients.append(cl)
96-
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
97-
# If it's a left expand_dims, recurse on the output
98-
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
89+
match (idx, cl.op, *cl.outputs):
90+
case (0, Blockwise(Solve(assume_a=assume_a_var)), *_) if (
91+
assume_a_var == assume_a
92+
):
93+
clients.append(cl)
94+
case (0, DimShuffle(is_left_expand_dims=True), cl_out):
95+
clients.extend(find_solve_clients(cl_out, assume_a))
9996
return clients
10097

10198
assume_a = node.op.core_op.assume_a
@@ -114,11 +111,11 @@ def find_solve_clients(var, assume_a):
114111

115112
# Find Solves using A.T
116113
for cl, _ in fgraph.clients[A]:
117-
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
118-
A_T = cl.out
119-
A_solve_clients_and_transpose.extend(
120-
(client, True) for client in find_solve_clients(A_T, assume_a)
121-
)
114+
match (cl.op, *cl.outputs):
115+
case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T):
116+
A_solve_clients_and_transpose.extend(
117+
(client, True) for client in find_solve_clients(A_T, assume_a)
118+
)
122119

123120
if not eager and len(A_solve_clients_and_transpose) == 1:
124121
# If there's a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
@@ -171,34 +168,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
171168
changed = False
172169
while True:
173170
for inner_node in new_scan_fgraph.toposort():
174-
if (
175-
isinstance(inner_node.op, Blockwise)
176-
and isinstance(inner_node.op.core_op, Solve)
177-
and inner_node.op.core_op.assume_a in allowed_assume_a
178-
):
179-
A, _b = inner_node.inputs
180-
if all(
181-
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
182-
for root_inp in graph_inputs([A])
171+
match (inner_node.op, *inner_node.inputs):
172+
case (Blockwise(Solve(assume_a=assume_a_var)), A, _b) if (
173+
assume_a_var in allowed_assume_a
183174
):
184-
if new_scan_fgraph is scan_op.fgraph:
185-
# Clone the first time to avoid mutating the original fgraph
186-
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
187-
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
188-
inner_node = equiv[inner_node] # type: ignore
189-
190-
replace_dict = _split_decomp_and_solve_steps(
191-
new_scan_fgraph,
192-
inner_node,
193-
eager=True,
194-
allowed_assume_a=allowed_assume_a,
195-
)
196-
assert isinstance(replace_dict, dict) and len(replace_dict) > 0, (
197-
"Rewrite failed"
198-
)
199-
new_scan_fgraph.replace_all(replace_dict.items())
200-
changed = True
201-
break # Break to start over with a fresh toposort
175+
if all(
176+
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
177+
for root_inp in graph_inputs([A])
178+
):
179+
if new_scan_fgraph is scan_op.fgraph:
180+
# Clone the first time to avoid mutating the original fgraph
181+
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
182+
non_sequences = {
183+
equiv[non_seq] for non_seq in non_sequences
184+
}
185+
inner_node = equiv[inner_node] # type: ignore
186+
187+
replace_dict = _split_decomp_and_solve_steps(
188+
new_scan_fgraph,
189+
inner_node,
190+
eager=True,
191+
allowed_assume_a=allowed_assume_a,
192+
)
193+
assert (
194+
isinstance(replace_dict, dict) and len(replace_dict) > 0
195+
), "Rewrite failed"
196+
new_scan_fgraph.replace_all(replace_dict.items())
197+
changed = True
198+
break # Break to start over with a fresh toposort
202199
else: # no_break
203200
break # Nothing else changed
204201

0 commit comments

Comments
 (0)