@@ -123,7 +123,7 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
123
123
124
124
return True
125
125
126
- def _process_node_groups (
126
+ def _process_all_nodes (
127
127
self ,
128
128
new_partition_id ,
129
129
partitions_by_id ,
@@ -133,97 +133,60 @@ def _process_node_groups(
133
133
partition_users ,
134
134
partition_map ,
135
135
):
136
- """Process nodes in predefined groups."""
137
- group_to_partition_id = {}
138
-
139
- if not self .node_groups :
140
- return group_to_partition_id
141
-
142
- processed_nodes = set ()
143
-
144
- # We have to create the partitions in reverse topological order
145
- # so we find the groups as we traverse backwards in the graph
146
- # this likely needs to be combined with the process_remaining_nodes
147
- # TODO: this currently doesn't work with _process_remaining_nodes so
148
- # if a user provides grouped nodes with operatorsupport, then this will
149
- # faile
136
+ """Process nodes into a partition."""
150
137
for node in reversed (self .graph_module .graph .nodes ):
151
- if node not in self .node_to_group :
138
+ if node in assignment or not self ._is_node_supported ( node ) :
152
139
continue
153
140
154
- if node in processed_nodes :
155
- continue
141
+ if node in self .all_nodes_in_groups :
142
+ group_idx = self .node_to_group [node ]
143
+ group = self .node_groups [group_idx ]
156
144
157
- group_idx = self .node_to_group [node ]
158
- group = self .node_groups [group_idx ]
159
-
160
- # Create a partition for group
161
- partition_id = next (new_partition_id )
162
- partition = Partition (id = partition_id , nodes = set ())
163
- partitions_by_id [partition_id ] = partition
164
- partitions_order [partition_id ] = partition_id
165
- group_to_partition_id [group_idx ] = partition_id
166
-
167
- # Add all supported nodes from the group to the partition
168
- for node in group :
169
- if self ._is_node_supported (node ):
170
- partition .add_node (node )
171
- assignment [node ] = partition_id
172
- nodes_order [node ] = partition_id
173
-
174
- # Set partition users
175
- partition_users [partition_id ] = {
176
- user
177
- for node in partition .nodes
178
- for user in node .users
179
- if user not in partition .nodes
180
- }
181
-
182
- # Update partition map
183
- for node in partition .nodes :
145
+ # Create a partition for group
146
+ partition_id = next (new_partition_id )
147
+ partition = Partition (id = partition_id , nodes = set ())
148
+ partitions_by_id [partition_id ] = partition
149
+ partitions_order [partition_id ] = partition_id
150
+
151
+ # Add all supported nodes from the group to the partition
152
+ for node in group :
153
+ if self ._is_node_supported (node ):
154
+ partition .add_node (node )
155
+ assignment [node ] = partition_id
156
+ nodes_order [node ] = partition_id
157
+
158
+ # Set partition users
159
+ partition_users [partition_id ] = {
160
+ user
161
+ for node in partition .nodes
162
+ for user in node .users
163
+ if user not in partition .nodes
164
+ }
165
+
166
+ # Update partition map
167
+ for node in partition .nodes :
168
+ for user in node .users :
169
+ target_id = assignment .get (user , None )
170
+ if target_id is not None and target_id != partition_id :
171
+ partition_map [partition_id ].add (target_id )
172
+ partition_map [partition_id ].update (partition_map [target_id ])
173
+ else :
174
+ partition_id = next (new_partition_id )
175
+ nodes_order [node ] = partition_id
176
+ partitions_order [partition_id ] = partition_id
177
+ partitions_by_id [partition_id ] = Partition (
178
+ id = partition_id , nodes = [node ]
179
+ )
180
+ assignment [node ] = partition_id
181
+ partition_users [partition_id ] = set (node .users )
182
+
183
+ # Update partition map
184
184
for user in node .users :
185
185
target_id = assignment .get (user )
186
- if target_id is not None and target_id != partition_id :
186
+ if target_id is not None :
187
187
partition_map [partition_id ].add (target_id )
188
188
partition_map [partition_id ].update (partition_map [target_id ])
189
189
190
- # all the nodes in the group have now been processed
191
- # so skip if we encoutner them again in our rev topo
192
- # iteration
193
- for node in group :
194
- processed_nodes .add (node )
195
-
196
- return group_to_partition_id
197
-
198
- def _process_remaining_nodes (
199
- self ,
200
- new_partition_id ,
201
- partitions_by_id ,
202
- assignment ,
203
- nodes_order ,
204
- partitions_order ,
205
- partition_users ,
206
- partition_map ,
207
- ):
208
- """Process nodes not in any predefined group."""
209
- for node in reversed (self .graph_module .graph .nodes ):
210
- if node in assignment or not self ._is_node_supported (node ):
211
- continue
212
-
213
- partition_id = next (new_partition_id )
214
- nodes_order [node ] = partition_id
215
- partitions_order [partition_id ] = partition_id
216
- partitions_by_id [partition_id ] = Partition (id = partition_id , nodes = [node ])
217
- assignment [node ] = partition_id
218
- partition_users [partition_id ] = set (node .users )
219
-
220
- # Update partition map
221
- for user in node .users :
222
- target_id = assignment .get (user )
223
- if target_id is not None :
224
- partition_map [partition_id ].add (target_id )
225
- partition_map [partition_id ].update (partition_map [target_id ])
226
-
227
190
def _merge_partitions (
228
191
self ,
229
192
partitions_by_id ,
@@ -378,19 +341,8 @@ def propose_partitions(self) -> list[Partition]:
378
341
partition_users = {} # Maps partition IDs to partition users
379
342
new_partition_id = itertools .count ()
380
343
381
- # Process nodes in predefined groups
382
- self ._process_node_groups (
383
- new_partition_id ,
384
- partitions_by_id ,
385
- assignment ,
386
- nodes_order ,
387
- partitions_order ,
388
- partition_users ,
389
- partition_map ,
390
- )
391
-
392
- # Process remaining nodes
393
- self ._process_remaining_nodes (
344
+ # Process all nodes into partitions
345
+ self ._process_all_nodes (
394
346
new_partition_id ,
395
347
partitions_by_id ,
396
348
assignment ,
0 commit comments