Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def load_groundtruth(self, swc_pointer, label_mask):
use_anisotropy=False,
verbose=self.verbose
)
return graph_loader.run(swc_pointer)
return graph_loader(swc_pointer)

def load_fragments(self, swc_pointer, gt_graphs):
"""
Expand Down Expand Up @@ -124,7 +124,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
use_anisotropy=self.use_anisotropy,
verbose=self.verbose
)
return graph_loader.run(swc_pointer)
return graph_loader(swc_pointer)

# --- Helpers ---
def get_all_node_labels(self, graphs):
Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(
anisotropy, selected_ids=selected_ids
)

def run(self, swc_pointer):
def __call__(self, swc_pointer):
"""
Builds a graphs by reading SWC files to extract content to load into a
SkeletonGraph object. Nodes are labeled if a label_mask is provided.
Expand Down
7 changes: 4 additions & 3 deletions src/segmentation_skeleton_metrics/skeleton_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node):
self.fragments_with_merge.add(fragment_graph.name)
self.merge_sites.append(
{
"Fragment_Name": fragment_graph.name,
"Segment_ID": fragment_graph.segment_id,
"GroundTruth_ID": gt_graph.name,
"Voxel": tuple(map(int, voxel)),
Expand Down Expand Up @@ -931,17 +932,17 @@ def __call__(self, gt_graphs, fragment_graphs, merge_sites):
pair_to_length = dict()
for i in merge_sites.index:
# Extract site info
segment_id = merge_sites["Segment_ID"][i]
fragment_name = str(merge_sites["Fragment_Name"][i])
gt_id = merge_sites["GroundTruth_ID"][i]
pair_id = (segment_id, gt_id)
pair_id = (fragment_name, gt_id)

# Check wheter to visit
if pair_id in pair_to_length:
merge_sites.loc[i, self.name] = pair_to_length[pair_id]
else:
# Get graphs
gt_graph = gt_graphs[gt_id]
fragment_graph = deepcopy(fragment_graphs[segment_id])
fragment_graph = deepcopy(fragment_graphs[fragment_name])

# Compute metric
pair_to_length[pair_id] = self.compute_added_length(
Expand Down
Loading