@@ -652,7 +652,7 @@ def elemwise_scalar_op_has_c_code(
652
652
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
653
653
nodes_bitflags = {node : 1 << i for i , node in enumerate (fgraph .toposort ())}
654
654
# Root variables have `None` as owner, which we can handle with a bitset of 0
655
- ancestors_bitsets = {None : 0 }
655
+ ancestors_bitsets : dict [ Apply | None , int ] = {None : 0 }
656
656
for node , node_bitflag in nodes_bitflags .items ():
657
657
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
658
658
ancestors_bitsets [node ] = reduce (
@@ -694,9 +694,13 @@ def elemwise_scalar_op_has_c_code(
694
694
# For simplicity, we always want to visit ancestors before clients
695
695
# For ancestors, we want to visit the later nodes first (those that have more dependencies)
696
696
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
697
- # To achieve this we use the bitflag as the sorting key (which encodes the topological order)
698
- # and negate it for ancestors.
699
- fuseables_nodes_queue = [(- starting_bitflag , starting_node )]
697
+ # To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order)
698
+ # and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we
699
+ # update the former when we find a fuseable subgraph, emulating the effect of recomputing the
700
+ # topological order on the remaining nodes.
701
+ fuseables_nodes_queue = [
702
+ (- ancestors_bitsets [starting_node ], starting_bitflag , starting_node )
703
+ ]
700
704
heapify (fuseables_nodes_queue )
701
705
702
706
# We keep 3 bitsets during the exploration of a new subgraph:
@@ -715,10 +719,12 @@ def elemwise_scalar_op_has_c_code(
715
719
unfuseable_clients_bitset = 0
716
720
717
721
while fuseables_nodes_queue :
718
- node_bitflag , node = heappop (fuseables_nodes_queue )
719
- is_ancestor = node_bitflag < 0
722
+ node_ancestors_bitset , node_bitflag , node = heappop (
723
+ fuseables_nodes_queue
724
+ )
725
+ is_ancestor = node_ancestors_bitset < 0
720
726
if is_ancestor :
721
- node_bitflag = - node_bitflag
727
+ node_ancestors_bitset = - node_ancestors_bitset
722
728
723
729
if node_bitflag & subgraph_bitset :
724
730
# Already part of the subgraph
@@ -728,7 +734,7 @@ def elemwise_scalar_op_has_c_code(
728
734
if node_bitflag & unfuseable_ancestors_bitset :
729
735
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
730
736
continue
731
- elif ancestors_bitsets [ node ] & unfuseable_clients_bitset :
737
+ elif node_ancestors_bitset & unfuseable_clients_bitset :
732
738
# This node depends on an unfuseable client of the subgraph, can't fuse
733
739
continue
734
740
@@ -749,7 +755,11 @@ def elemwise_scalar_op_has_c_code(
749
755
if node in fuseable_clients .get (ancestor_node , ()):
750
756
heappush (
751
757
fuseables_nodes_queue ,
752
- (- ancestor_bitflag , ancestor_node ),
758
+ (
759
+ - ancestors_bitsets [ancestor_node ],
760
+ ancestor_bitflag ,
761
+ ancestor_node ,
762
+ ),
753
763
)
754
764
else :
755
765
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
@@ -764,16 +774,17 @@ def elemwise_scalar_op_has_c_code(
764
774
if is_ancestor and (client_bitflag & subgraph_bitset ):
765
775
continue
766
776
if client in next_fuseable_clients :
767
- heappush (fuseables_nodes_queue , (client_bitflag , client ))
777
+ heappush (
778
+ fuseables_nodes_queue ,
779
+ (ancestors_bitsets [client ], client_bitflag , client ),
780
+ )
768
781
else :
769
782
# If a client is not in the node's fuseable clients set, it's not fuseable with it,
770
783
# nor any of its clients. But we don't need to keep track of those as any downstream
771
784
# client we may consider later will also depend on this unfuseable client and be rejected
772
785
unfuseable_clients_bitset |= client_bitflag
773
786
774
- # Finished exploring this subgraph
775
- all_subgraphs_bitset |= subgraph_bitset
776
-
787
+ # Finished expansion of subgraph
777
788
if subgraph_bitset == starting_bitflag :
778
789
# We ended where we started, no fusion possible
779
790
continue
@@ -816,6 +827,18 @@ def elemwise_scalar_op_has_c_code(
816
827
for out in subgraph_outputs :
817
828
fuseable_clients .pop (out .owner , None )
818
829
830
+ # When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes.
831
+ # Nodes that previously depended on a subset of the fused outputs, now depend on all of them.
832
+ if len (subgraph_outputs ) > 1 :
833
+ subgraph_and_ancestors = (
834
+ subgraph_bitset | unfuseable_ancestors_bitset
835
+ )
836
+ ancestors_bitsets |= (
837
+ (node , node_ancestors_bitset | subgraph_and_ancestors )
838
+ for node , node_ancestors_bitset in ancestors_bitsets .items ()
839
+ if node_ancestors_bitset & subgraph_bitset
840
+ )
841
+
819
842
# Add new subgraph to sorted_subgraphs
820
843
# Because we start from sink nodes in reverse topological order, most times new subgraphs
821
844
# don't depend on previous subgraphs, so we can just append them at the end.
@@ -828,8 +851,7 @@ def elemwise_scalar_op_has_c_code(
828
851
else :
829
852
# But not here, so we need to find the right position for insertion.
830
853
# We iterate through the previous subgraphs in topological order (reverse of the stored order).
831
- # We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again.
832
- # The (index + 1) of the firs iteration where the check passes is the correct insertion position.
854
+ # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes.
833
855
remaining_subgraphs_bitset = all_subgraphs_bitset
834
856
for index , (other_subgraph_bitset , _ ) in enumerate (
835
857
reversed (sorted_subgraphs )
@@ -840,12 +862,20 @@ def elemwise_scalar_op_has_c_code(
840
862
unfuseable_ancestors_bitset & remaining_subgraphs_bitset
841
863
):
842
864
break # bingo
865
+ else : # no-break
866
+ raise RuntimeError (
867
+ "Failed to find insertion point for fused subgraph"
868
+ )
843
869
sorted_subgraphs .insert (
844
870
- (index + 1 ),
845
871
(subgraph_bitset , (subgraph_inputs , subgraph_outputs )),
846
872
)
847
873
848
- # yield from sorted_subgraphs, discarding the subgraph_bitset
874
+ # Add subgraph to all_subgraphs_bitset
875
+ all_subgraphs_bitset |= subgraph_bitset
876
+
877
+ # Finished exploring the whole graph
878
+ # Yield from sorted_subgraphs, discarding the subgraph_bitset
849
879
yield from (io for _ , io in sorted_subgraphs )
850
880
851
881
max_operands = elemwise_max_operands_fct (None )
0 commit comments