@@ -652,12 +652,12 @@ def elemwise_scalar_op_has_c_code(
652
652
# `ancestors_bitsets[C] & (nodes_bitflags[A] | nodes_bitflags[B]) != 0`
653
653
nodes_bitflags = {node : 1 << i for i , node in enumerate (fgraph .toposort ())}
654
654
# Root variables have `None` as owner, which we can handle with a bitset of 0
655
- ancestors_bitset = {None : 0 }
655
+ ancestors_bitsets = {None : 0 }
656
656
for node , node_bitflag in nodes_bitflags .items ():
657
657
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
658
- ancestors_bitset [node ] = reduce (
658
+ ancestors_bitsets [node ] = reduce (
659
659
or_ ,
660
- (ancestors_bitset [inp .owner ] for inp in node .inputs ),
660
+ (ancestors_bitsets [inp .owner ] for inp in node .inputs ),
661
661
node_bitflag ,
662
662
)
663
663
# Handle root and leaf nodes gracefully
@@ -666,10 +666,12 @@ def elemwise_scalar_op_has_c_code(
666
666
nodes_bitflags [None ] = 0
667
667
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them
668
668
out_bitflag = 1 << len (nodes_bitflags )
669
- for out in fg .outputs :
670
- for client , _ in fg_clients [out ]:
671
- if isinstance (client .op , Output ):
672
- nodes_bitflags [client ] = out_bitflag
669
+ nodes_bitflags |= (
670
+ (client , out_bitflag )
671
+ for out in fg .outputs
672
+ for client , _ in fg_clients [out ]
673
+ if isinstance (client .op , Output )
674
+ )
673
675
674
676
# Start main loop to find collection of fuseable subgraphs
675
677
# We store the collection in `sorted_subgraphs`, in reverse topological order
@@ -726,7 +728,7 @@ def elemwise_scalar_op_has_c_code(
726
728
if node_bitflag & unfuseable_ancestors_bitset :
727
729
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
728
730
continue
729
- elif ancestors_bitset [node ] & unfuseable_clients_bitset :
731
+ elif ancestors_bitsets [node ] & unfuseable_clients_bitset :
730
732
# This node depends on an unfuseable client of the subgraph, can't fuse
731
733
continue
732
734
@@ -742,7 +744,7 @@ def elemwise_scalar_op_has_c_code(
742
744
for inp in node .inputs :
743
745
ancestor_node = inp .owner
744
746
ancestor_bitflag = nodes_bitflags [ancestor_node ]
745
- if ancestor_bitflag & subgraph_bitset :
747
+ if ( not is_ancestor ) and ( ancestor_bitflag & subgraph_bitset ) :
746
748
continue
747
749
if node in fuseable_clients .get (ancestor_node , ()):
748
750
heappush (
@@ -752,14 +754,14 @@ def elemwise_scalar_op_has_c_code(
752
754
else :
753
755
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
754
756
# nor with any of the ancestor's ancestors
755
- unfuseable_ancestors_bitset |= ancestors_bitset [
757
+ unfuseable_ancestors_bitset |= ancestors_bitsets [
756
758
ancestor_node
757
759
]
758
760
759
761
next_fuseable_clients = fuseable_clients .get (node , ())
760
762
for client , _ in fg_clients [node .outputs [0 ]]:
761
763
client_bitflag = nodes_bitflags [client ]
762
- if client_bitflag & subgraph_bitset :
764
+ if is_ancestor and ( client_bitflag & subgraph_bitset ) :
763
765
continue
764
766
if client in next_fuseable_clients :
765
767
heappush (fuseables_nodes_queue , (client_bitflag , client ))
0 commit comments