Skip to content

Commit b6a9b68

Browse files
committed
Fix FusionOptimizer bug
When a subgraph with multiple outputs is "implicitly" claimed, it can change the dependencies of remaining nodes. A node that depended only on a subset of the subgraph outputs now depends on all of them. Not taking this into account could lead to circular dependent Composites
1 parent b68c74d commit b6a9b68

File tree

2 files changed

+83
-16
lines changed

2 files changed

+83
-16
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def elemwise_scalar_op_has_c_code(
652652
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
653653
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
654654
# Root variables have `None` as owner, which we can handle with a bitset of 0
655-
ancestors_bitsets = {None: 0}
655+
ancestors_bitsets: dict[Apply | None, int] = {None: 0}
656656
for node, node_bitflag in nodes_bitflags.items():
657657
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
658658
ancestors_bitsets[node] = reduce(
@@ -694,9 +694,13 @@ def elemwise_scalar_op_has_c_code(
694694
# For simplicity, we always want to visit ancestors before clients
695695
# For ancestors, we want to visit the later nodes first (those that have more dependencies)
696696
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
697-
# To achieve this we use the bitflag as the sorting key (which encodes the topological order)
698-
# and negate it for ancestors.
699-
fuseables_nodes_queue = [(-starting_bitflag, starting_node)]
697+
# To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order)
698+
# and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we
699+
# update the former when we find a fuseable subgraph, emulating the effect of recomputing the
700+
# topological order on the remaining nodes.
701+
fuseables_nodes_queue = [
702+
(-ancestors_bitsets[starting_node], starting_bitflag, starting_node)
703+
]
700704
heapify(fuseables_nodes_queue)
701705

702706
# We keep 3 bitsets during the exploration of a new subgraph:
@@ -715,10 +719,12 @@ def elemwise_scalar_op_has_c_code(
715719
unfuseable_clients_bitset = 0
716720

717721
while fuseables_nodes_queue:
718-
node_bitflag, node = heappop(fuseables_nodes_queue)
719-
is_ancestor = node_bitflag < 0
722+
node_ancestors_bitset, node_bitflag, node = heappop(
723+
fuseables_nodes_queue
724+
)
725+
is_ancestor = node_ancestors_bitset < 0
720726
if is_ancestor:
721-
node_bitflag = -node_bitflag
727+
node_ancestors_bitset = -node_ancestors_bitset
722728

723729
if node_bitflag & subgraph_bitset:
724730
# Already part of the subgraph
@@ -728,7 +734,7 @@ def elemwise_scalar_op_has_c_code(
728734
if node_bitflag & unfuseable_ancestors_bitset:
729735
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
730736
continue
731-
elif ancestors_bitsets[node] & unfuseable_clients_bitset:
737+
elif node_ancestors_bitset & unfuseable_clients_bitset:
732738
# This node depends on an unfuseable client of the subgraph, can't fuse
733739
continue
734740

@@ -749,7 +755,11 @@ def elemwise_scalar_op_has_c_code(
749755
if node in fuseable_clients.get(ancestor_node, ()):
750756
heappush(
751757
fuseables_nodes_queue,
752-
(-ancestor_bitflag, ancestor_node),
758+
(
759+
-ancestors_bitsets[ancestor_node],
760+
ancestor_bitflag,
761+
ancestor_node,
762+
),
753763
)
754764
else:
755765
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
@@ -764,16 +774,17 @@ def elemwise_scalar_op_has_c_code(
764774
if is_ancestor and (client_bitflag & subgraph_bitset):
765775
continue
766776
if client in next_fuseable_clients:
767-
heappush(fuseables_nodes_queue, (client_bitflag, client))
777+
heappush(
778+
fuseables_nodes_queue,
779+
(ancestors_bitsets[client], client_bitflag, client),
780+
)
768781
else:
769782
# If a client is not in the node's fuseable clients set, it's nto fuseable with it,
770783
# nor any of its clients. But we don't need to keep track of those as any downstream
771784
# client we may consider later will also depend on this unfuseable client and be rejected
772785
unfuseable_clients_bitset |= client_bitflag
773786

774-
# Finished exploring this subgraph
775-
all_subgraphs_bitset |= subgraph_bitset
776-
787+
# Finished expansion of subgraph
777788
if subgraph_bitset == starting_bitflag:
778789
# We ended were we started, no fusion possible
779790
continue
@@ -816,6 +827,18 @@ def elemwise_scalar_op_has_c_code(
816827
for out in subgraph_outputs:
817828
fuseable_clients.pop(out.owner, None)
818829

830+
# When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes.
831+
# Nodes that previously depended on a subset of the fused outputs, now depend on all of them.
832+
if len(subgraph_outputs) > 1:
833+
subgraph_and_ancestors = (
834+
subgraph_bitset | unfuseable_ancestors_bitset
835+
)
836+
ancestors_bitsets |= (
837+
(node, node_ancestors_bitset | subgraph_and_ancestors)
838+
for node, node_ancestors_bitset in ancestors_bitsets.items()
839+
if node_ancestors_bitset & subgraph_bitset
840+
)
841+
819842
# Add new subgraph to sorted_subgraphs
820843
# Because we start from sink nodes in reverse topological order, most times new subgraphs
821844
# don't depend on previous subgraphs, so we can just append them at the end.
@@ -828,8 +851,7 @@ def elemwise_scalar_op_has_c_code(
828851
else:
829852
# But not here, so we need to find the right position for insertion.
830853
# We iterate through the previous subgraphs in topological order (reverse of the stored order).
831-
# We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again.
832-
# The (index + 1) of the firs iteration where the check passes is the correct insertion position.
854+
# We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes.
833855
remaining_subgraphs_bitset = all_subgraphs_bitset
834856
for index, (other_subgraph_bitset, _) in enumerate(
835857
reversed(sorted_subgraphs)
@@ -840,12 +862,20 @@ def elemwise_scalar_op_has_c_code(
840862
unfuseable_ancestors_bitset & remaining_subgraphs_bitset
841863
):
842864
break # bingo
865+
else: # no-break
866+
raise RuntimeError(
867+
"Failed to find insertion point for fused subgraph"
868+
)
843869
sorted_subgraphs.insert(
844870
-(index + 1),
845871
(subgraph_bitset, (subgraph_inputs, subgraph_outputs)),
846872
)
847873

848-
# yield from sorted_subgraphs, discarding the subgraph_bitset
874+
# Add subgraph to all_subgraphs_bitset
875+
all_subgraphs_bitset |= subgraph_bitset
876+
877+
# Finished exploring the whole graph
878+
# Yield from sorted_subgraphs, discarding the subgraph_bitset
849879
yield from (io for _, io in sorted_subgraphs)
850880

851881
max_operands = elemwise_max_operands_fct(None)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,43 @@ def test_no_warning_from_old_client(self):
14261426
np.log(1 - np.exp(-2)),
14271427
)
14281428

1429+
def test_joint_circular_dependency(self):
1430+
# Test a case where fused subgraphs could induce a circular dependency
1431+
x = matrix("x")
1432+
neg = pt.neg(x)
1433+
eq = pt.eq(x.sum(axis=0), 0)
1434+
sub = pt.sub(eq, neg)
1435+
exp = pt.exp(neg.sum(axis=0))
1436+
# We test arbitrary add and output orders, to make sure our algorithm
1437+
# is robust to valid toposort variations.
1438+
for add_order in [(exp, eq), (eq, exp)]:
1439+
add = pt.add(*add_order)
1440+
1441+
# The naive fused graphs to consider are {sub, neg} and {add, exp, eq},
1442+
# which is not valid because sub depends on eq, while add/exp depends on neg.
1443+
# Instead, we can either fuse both {sub, neg} and {add, exp} or just {add, exp, eq}
1444+
1445+
for out_order in [(sub, add), (add, sub)]:
1446+
fgraph = FunctionGraph([x], out_order, clone=True)
1447+
_, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph)
1448+
# (nb_fused, nb_replaced) would be (2, 5) if we did the invalid fusion
1449+
assert (nb_fused, nb_replaced) in ((2, 4), (1, 3))
1450+
fused_nodes = {
1451+
frozenset(
1452+
scalar_n.op for scalar_n in n.op.scalar_op.fgraph.apply_nodes
1453+
)
1454+
for n in fgraph.apply_nodes
1455+
if isinstance(n.op, Elemwise)
1456+
and isinstance(n.op.scalar_op, Composite)
1457+
}
1458+
if nb_fused == 1:
1459+
assert fused_nodes == {frozenset((ps.add, ps.exp, ps.eq))}
1460+
else:
1461+
assert fused_nodes == {
1462+
frozenset((ps.sub, ps.neg)),
1463+
frozenset((ps.add, ps.exp)),
1464+
}
1465+
14291466

14301467
class TimesN(ps.basic.UnaryScalarOp):
14311468
"""

0 commit comments

Comments
 (0)