|
6 | 6 | |
7 | 7 |
|
8 | 8 | """ |
9 | | -from collections import defaultdict |
10 | 9 | from concurrent.futures import ProcessPoolExecutor, as_completed |
11 | 10 | from random import sample |
12 | 11 | from tqdm import tqdm |
@@ -241,29 +240,7 @@ def remove_nodes(graph, target_label): |
241 | 240 | return graph |
242 | 241 |
|
243 | 242 |
|
244 | | -def init_label_to_nodes(graph): |
245 | | - """ |
246 | | - Initializes a dictionary that maps a label to nodes with that label. |
247 | | -
|
248 | | - Parameters |
249 | | - ---------- |
250 | | - graph : networkx.Graph |
251 | | - Graph to be searched. |
252 | | -
|
253 | | - Returns |
254 | | - ------- |
255 | | - dict |
256 | | - Dictionary that maps a label to nodes with that label. |
257 | | -
|
258 | | - """ |
259 | | - label_to_nodes = defaultdict(set) |
260 | | - node_to_label = nx.get_node_attributes(graph, "label") |
261 | | - for i, label in node_to_label.items(): |
262 | | - label_to_nodes[label].add(i) |
263 | | - return label_to_nodes |
264 | | - |
265 | | - |
266 | | -# -- eval tools -- |
| 243 | +# -- Miscellaneous -- |
267 | 244 | def compute_run_lengths(graph): |
268 | 245 | """ |
269 | 246 | Computes the path length of each connected component in "graph". |
@@ -327,47 +304,13 @@ def count_splits(graph): |
327 | 304 | Number of splits in "graph". |
328 | 305 |
|
329 | 306 | """ |
330 | | - return max(len(list(nx.connected_components(graph))) - 1, 0) |
| 307 | + return max(nx.number_connected_components(graph) - 1, 0) |
331 | 308 |
|
332 | 309 |
|
333 | | -# -- Miscellaneous -- |
334 | 310 | def get_segment_id(swc_id): |
335 | 311 | return int(swc_id.split(".")[0]) |
336 | 312 |
|
337 | 313 |
|
338 | | -def get_node_labels(graphs): |
339 | | - """ |
340 | | - Creates a dictionary that maps a graph id to the set of unique labels of |
341 | | - nodes in that graph. |
342 | | -
|
343 | | - Parameters |
344 | | - ---------- |
345 | | - graphs : dict |
346 | | - Graphs to be searched. |
347 | | -
|
348 | | - Returns |
349 | | - ------- |
350 | | - dict |
351 | | - Dictionary that maps a graph id to the set of unique labels of nodes |
352 | | - in that graph. |
353 | | -
|
354 | | - """ |
355 | | - with ProcessPoolExecutor() as executor: |
356 | | - # Assign processes |
357 | | - processes = list() |
358 | | - for key, graph in graphs.items(): |
359 | | - processes.append( |
360 | | - executor.submit(init_label_to_nodes, graph, True, key) |
361 | | - ) |
362 | | - |
363 | | - # Store results |
364 | | - graph_to_labels = dict() |
365 | | - for cnt, process in enumerate(as_completed(processes)): |
366 | | - key, label_to_nodes = process.result() |
367 | | - graph_to_labels[key] = set(label_to_nodes.keys()) |
368 | | - return graph_to_labels |
369 | | - |
370 | | - |
371 | 314 | def sample_leaf(graph): |
372 | 315 | """ |
373 | 316 | Samples leaf node from "graph". |
|
0 commit comments