1717 Partitioner ,
1818 PartitionResult ,
1919)
20+
21+ from sympy .logic .boolalg import disjuncts
22+ from torch ._export .utils import is_buffer , is_lifted_tensor_constant , is_param
2023from torch .fx .passes .infra .partitioner import Partition
2124
2225
26+ def is_constant_data (ep : ExportedProgram , node : torch .fx .Node ) -> bool :
27+ return (
28+ is_param (ep , node ) or is_buffer (ep , node ) or is_lifted_tensor_constant (ep , node )
29+ )
30+
31+
2332def format_target_name (target_name : str ) -> str :
2433 """
2534 We remove the dialect name space from the target name. We generally
@@ -100,6 +109,35 @@ def get_partition(
100109 pass
101110
102111
112+ class DSJ :
113+ """
114+ Disjoint set union data structure used to find connected components in the graph.
115+ """
116+
117+ def __init__ (self ):
118+ self .parent = {}
119+
120+ def find (self , x ):
121+ self .parent .setdefault (x , x )
122+ if self .parent [x ] != x :
123+ self .parent [x ] = self .find (self .parent [x ])
124+ return self .parent [x ]
125+
126+ def union (self , x , y ):
127+ self .parent [self .find (x )] = self .find (y )
128+
129+ def contains (self , x ):
130+ return x in self .parent
131+
132+ def gen_groups (self ):
133+ groups = {}
134+ for node in self .parent .keys ():
135+ root = self .find (node )
136+ groups .setdefault (root , set ()).add (node )
137+
138+ return [list (group ) for group in groups .values ()]
139+
140+
103141class ConfigerationBasedPartitioner (Partitioner ):
104142 def __init__ (
105143 self ,
@@ -162,17 +200,8 @@ def filter_fn(node: torch.fx.Node) -> bool:
162200 def get_matched_nodes_from_configs (
163201 self , ep : ExportedProgram
164202 ) -> List [List [torch .fx .Node ]]:
165- # disjoint set union
166- parent = {}
167-
168- def find (x ):
169- parent .setdefault (x , x )
170- if parent [x ] != x :
171- parent [x ] = find (parent [x ])
172- return parent [x ]
173-
174- def union (x , y ):
175- parent [find (x )] = find (y )
203+ # disjoint set union for merging partitions
204+ dsj = DSJ ()
176205
177206 # gather supported nodes
178207 gm = ep .graph_module
@@ -188,18 +217,22 @@ def union(x, y):
188217 if not node_config .check_constraints (node , ep ):
189218 continue
190219
191- partition = node_config .get_partition (node , ep )
220+ partition_candidate = node_config .get_partition (node , ep )
221+ partition = []
222+ for node in partition_candidate :
223+ # partitioner infra copies constant data across partitions, so it
224+ # is ok if this partition doesn't have it
225+ if is_constant_data (ep , node ) and dsj .contains (node ):
226+ continue
227+ partition .append (node )
228+
229+ # Union overlaps into a single group
192230 if len (partition ) > 0 :
193- parent [ partition [0 ]] = partition [ 0 ]
231+ dsj . find ( partition [0 ])
194232 for i in range (1 , len (partition )):
195- union (partition [0 ], partition [i ])
233+ dsj . union (partition [0 ], partition [i ])
196234
197- groups = {}
198- for node in parent .keys ():
199- root = find (node )
200- groups .setdefault (root , set ()).add (node )
201-
202- return [list (group ) for group in groups .values ()]
235+ return dsj .gen_groups ()
203236
204237 def generate_partitions (self , ep : ExportedProgram ) -> List [Partition ]:
205238 matched_nodes = self .get_matched_nodes_from_configs (ep )
0 commit comments