@@ -99,6 +99,20 @@ def init_target_graphs(self, paths, anisotropy):
9999 self .target_graphs [swc_id ] = to_graph (path , anisotropy = anisotropy )
100100
101101 def init_pred_graphs (self ):
102+ """
103+ Initializes "self.pred_graphs" by copying each graph in
104+ "self.target_graphs", then labels each node with the label in
105+ "self.labels" that coincides with it.
106+
107+ Parameters
108+ ----------
109+ None
110+
111+ Returns
112+ -------
113+ None
114+
115+ """
102116 print ("Labelling Target Graphs..." )
103117 t0 = time ()
104118 self .pred_graphs = dict ()
@@ -227,7 +241,8 @@ def compute_metrics(self):
227241 self .quantify_merges ()
228242
229243 # Compute metrics
230- self .compile_results ()
244+ full_results , avg_results = self .compile_results ()
245+ return full_results , avg_results
231246
232247 def detect_splits (self ):
233248 """
@@ -397,9 +412,39 @@ def detect_merges(self):
397412 print (f"\n Runtime: { round (t , 2 )} { unit } \n " )
398413
399414 def init_merge_counter (self ):
415+ """
416+ Initializes a dictionary that is used to count the number of merge
417+ type mistakes for each pred_graph.
418+
419+ Parameters
420+ ----------
421+ None
422+
423+ Returns
424+ -------
425+ dict
426+ Dictionary used to count number of merge type mistakes.
427+
428+ """
400429 return dict ([(swc_id , 0 ) for swc_id in self .pred_graphs .keys ()])
401430
402431 def process_merge (self , swc_id , label ):
432+ """
433+ Once a merge has been detected that corresponds to "label", every node
434+ in "self.pred_graph[swc_id]" with that label is deleted.
435+
436+ Parameters
437+ ----------
438+ swc_id : str
439+ Key associated with the pred_graph to be searched.
440+ label : int
441+ Label assocatied with a merge.
442+
443+ Returns
444+ -------
445+ None
446+
447+ """
403448 # Update graph
404449 graph = self .pred_graphs [swc_id ].copy ()
405450 graph , merged_cnt = gutils .delete_nodes (graph , label , return_cnt = True )
@@ -411,29 +456,124 @@ def process_merge(self, swc_id, label):
411456 self .merged_cnts [swc_id ] += merged_cnt
412457
413458 def quantify_merges (self ):
459+ """
460+ Computes the percentage of merged edges for each pred_graph.
461+
462+ Parameters
463+ ----------
464+ None
465+
466+ Returns
467+ -------
468+ None
469+
470+ """
414471 self .merged_percents = dict ()
415472 for swc_id in self .target_graphs .keys ():
416473 n_edges = self .target_graphs [swc_id ].number_of_edges ()
417474 self .merged_percents [swc_id ] = self .merged_cnts [swc_id ] / n_edges
418475
419476 def compile_results (self ):
477+ """
478+ Compiles a dictionary containing the metrics computed by this module.
479+
480+ Parameters
481+ ----------
482+ None
483+
484+ Returns
485+ -------
486+ full_results : dict
487+ Dictionary where the keys are swc_ids and the values are the result
488+ of computing each metric for the corresponding graphs.
489+ avg_result : dict
490+ Dictionary where the keys are names of metrics computed by this
491+ module and values are the averaged result over all swc_ids.
492+
493+ """
420494 # Compute remaining metrics
421495 self .compute_edge_accuracy ()
422496 self .compute_erl ()
423497
498+ # Summarize results
499+ swc_ids , full_results = self .generate_report ()
500+ avg_results = dict ([(k , np .mean (v )) for k , v in full_results .items ()])
501+ full_results = dict (zip (swc_ids , full_results ))
502+ return full_results , avg_results
503+
504+ def generate_report (self ):
505+ """
506+ Generates a report by creating a list of the results for each metric.
507+ Each item in this list corresponds to a graph in "self.pred_graphs"
508+ and this list is ordered with respect to "swc_ids".
509+
510+ Parameters
511+ ----------
512+ None
513+
514+ Results
515+ -------
516+ swc_ids : list[str]
517+ Specifies the ordering of results for each value in "stats".
518+ stats : dict
519+ Dictionary where the keys are metrics and values are the result of
520+ computing that metric for each graph in "self.pred_graphs".
521+
522+ """
523+ swc_ids = list (self .pred_graphs .keys ())
524+ swc_ids .sort ()
525+ stats = {
526+ "# splits" : generate_result (swc_ids , self .split_cnts ),
527+ "# merges" : generate_result (swc_ids , self .merge_cnts ),
528+ "% omit edges" : generate_result (swc_ids , self .omit_percents ),
529+ "% merged edges" : generate_result (swc_ids , self .merged_percents ),
530+ "edge accuracy" : generate_result (swc_ids , self .edge_accuracy ),
531+ "erl" : generate_result (swc_ids , self .erl ),
532+ "normalized erl" : generate_result (swc_ids , self .normalized_erl ),
533+ }
534+ return swc_ids , stats
535+
424536 def compute_edge_accuracy (self ):
537+ """
538+ Computes the edge accuracy of each pred_graph.
539+
540+ Parameters
541+ ----------
542+ None
543+
544+ Returns
545+ -------
546+ None
547+
548+ """
425549 self .edge_accuracy = dict ()
426550 for swc_id in self .target_graphs .keys ():
427551 omit_percent = self .omit_percents [swc_id ]
428552 merged_percent = self .merged_percents [swc_id ]
429553 self .edge_accuracy [swc_id ] = 1 - omit_percent - merged_percent
430554
431555 def compute_erl (self ):
556+ """
557+ Computes the expected run length (ERL) of each pred_graph.
558+
559+ Parameters
560+ ----------
561+ None
562+
563+ Returns
564+ -------
565+ None
566+
567+ """
432568 self .erl = dict ()
569+ self .normalized_erl = dict ()
433570 for swc_id in self .target_graphs .keys ():
434- graph = self .pred_graphs [swc_id ]
435- path_lengths = gutils .compute_run_lengths (graph )
571+ pred_graph = self .pred_graphs [swc_id ]
572+ target_graph = self .target_graphs [swc_id ]
573+ path_lengths = gutils .compute_run_lengths (pred_graph )
574+ path_length = gutils .compute_path_length (target_graph )
436575 self .erl [swc_id ] = np .mean (path_lengths )
576+ self .normalized_erl [swc_id ] = np .mean (path_lengths ) / path_length
437577
438578
439579# -- utils --
@@ -480,3 +620,25 @@ def remove_edge(dfs_edges, edge):
480620 elif (edge [1 ], edge [0 ]) in dfs_edges :
481621 dfs_edges .remove ((edge [1 ], edge [0 ]))
482622 return dfs_edges
623+
624+
625+ def generate_result (swc_ids , stats ):
626+ """
627+ Reorders items in "stats" with respect to the order defined by "swc_ids".
628+
629+ Parameters
630+ ----------
631+ swc_ids : list[str]
632+ List of all swc_ids of graphs in "self.pred_graphs".
633+ stats : dict
634+ Dictionary where the keys are swc_ids and values are the result of
635+ computing some metrics.
636+
637+ Returns
638+ -------
639+ list
640+ Reorded items in "stats" with respect to the order defined by
641+ "swc_ids".
642+
643+ """
644+ return [stats [swc_id ] for swc_id in swc_ids ]
0 commit comments