Skip to content

Commit b68c74d

Browse files
committed
Small tweaks to FusionOptimizer
1 parent 1d825dd commit b68c74d

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -652,12 +652,12 @@ def elemwise_scalar_op_has_c_code(
652652
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
653653
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
654654
# Root variables have `None` as owner, which we can handle with a bitset of 0
655-
ancestors_bitset = {None: 0}
655+
ancestors_bitsets = {None: 0}
656656
for node, node_bitflag in nodes_bitflags.items():
657657
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
658-
ancestors_bitset[node] = reduce(
658+
ancestors_bitsets[node] = reduce(
659659
or_,
660-
(ancestors_bitset[inp.owner] for inp in node.inputs),
660+
(ancestors_bitsets[inp.owner] for inp in node.inputs),
661661
node_bitflag,
662662
)
663663
# Handle root and leaf nodes gracefully
@@ -666,10 +666,12 @@ def elemwise_scalar_op_has_c_code(
666666
nodes_bitflags[None] = 0
667667
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them
668668
out_bitflag = 1 << len(nodes_bitflags)
669-
for out in fg.outputs:
670-
for client, _ in fg_clients[out]:
671-
if isinstance(client.op, Output):
672-
nodes_bitflags[client] = out_bitflag
669+
nodes_bitflags |= (
670+
(client, out_bitflag)
671+
for out in fg.outputs
672+
for client, _ in fg_clients[out]
673+
if isinstance(client.op, Output)
674+
)
673675

674676
# Start main loop to find collection of fuseable subgraphs
675677
# We store the collection in `sorted_subgraphs`, in reverse topological order
@@ -726,7 +728,7 @@ def elemwise_scalar_op_has_c_code(
726728
if node_bitflag & unfuseable_ancestors_bitset:
727729
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
728730
continue
729-
elif ancestors_bitset[node] & unfuseable_clients_bitset:
731+
elif ancestors_bitsets[node] & unfuseable_clients_bitset:
730732
# This node depends on an unfuseable client of the subgraph, can't fuse
731733
continue
732734

@@ -742,7 +744,7 @@ def elemwise_scalar_op_has_c_code(
742744
for inp in node.inputs:
743745
ancestor_node = inp.owner
744746
ancestor_bitflag = nodes_bitflags[ancestor_node]
745-
if ancestor_bitflag & subgraph_bitset:
747+
if (not is_ancestor) and (ancestor_bitflag & subgraph_bitset):
746748
continue
747749
if node in fuseable_clients.get(ancestor_node, ()):
748750
heappush(
@@ -752,14 +754,14 @@ def elemwise_scalar_op_has_c_code(
752754
else:
753755
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
754756
# nor with any of the ancestor's ancestors
755-
unfuseable_ancestors_bitset |= ancestors_bitset[
757+
unfuseable_ancestors_bitset |= ancestors_bitsets[
756758
ancestor_node
757759
]
758760

759761
next_fuseable_clients = fuseable_clients.get(node, ())
760762
for client, _ in fg_clients[node.outputs[0]]:
761763
client_bitflag = nodes_bitflags[client]
762-
if client_bitflag & subgraph_bitset:
764+
if is_ancestor and (client_bitflag & subgraph_bitset):
763765
continue
764766
if client in next_fuseable_clients:
765767
heappush(fuseables_nodes_queue, (client_bitflag, client))

0 commit comments

Comments
 (0)