Skip to content

Commit 74e9d12

Browse files
anna-grimanna-grim
andauthored
Feat s3 loading (#178)
* refactor: improved txt reader * remove print * feat: load swcs from s3 * feat: read json * bug: read local txt * bug: swc loading, merge sites * bug: merge site names --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent cf96ce7 commit 74e9d12

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/segmentation_skeleton_metrics/data_handling/graph_loading.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def load_groundtruth(self, swc_pointer, label_mask):
8989
use_anisotropy=False,
9090
verbose=self.verbose
9191
)
92-
return graph_loader.run(swc_pointer)
92+
return graph_loader(swc_pointer)
9393

9494
def load_fragments(self, swc_pointer, gt_graphs):
9595
"""
@@ -124,7 +124,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
124124
use_anisotropy=self.use_anisotropy,
125125
verbose=self.verbose
126126
)
127-
return graph_loader.run(swc_pointer)
127+
return graph_loader(swc_pointer)
128128

129129
# --- Helpers ---
130130
def get_all_node_labels(self, graphs):
@@ -198,7 +198,7 @@ def __init__(
198198
anisotropy, selected_ids=selected_ids
199199
)
200200

201-
def run(self, swc_pointer):
201+
def __call__(self, swc_pointer):
202202
"""
203203
Builds a graphs by reading SWC files to extract content to load into a
204204
SkeletonGraph object. Nodes are labeled if a label_mask is provided.

src/segmentation_skeleton_metrics/skeleton_metrics.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node):
525525
self.fragments_with_merge.add(fragment_graph.name)
526526
self.merge_sites.append(
527527
{
528+
"Fragment_Name": fragment_graph.name,
528529
"Segment_ID": fragment_graph.segment_id,
529530
"GroundTruth_ID": gt_graph.name,
530531
"Voxel": tuple(map(int, voxel)),
@@ -931,17 +932,17 @@ def __call__(self, gt_graphs, fragment_graphs, merge_sites):
931932
pair_to_length = dict()
932933
for i in merge_sites.index:
933934
# Extract site info
934-
segment_id = merge_sites["Segment_ID"][i]
935+
fragment_name = str(merge_sites["Fragment_Name"][i])
935936
gt_id = merge_sites["GroundTruth_ID"][i]
936-
pair_id = (segment_id, gt_id)
937+
pair_id = (fragment_name, gt_id)
937938

938939
# Check wheter to visit
939940
if pair_id in pair_to_length:
940941
merge_sites.loc[i, self.name] = pair_to_length[pair_id]
941942
else:
942943
# Get graphs
943944
gt_graph = gt_graphs[gt_id]
944-
fragment_graph = deepcopy(fragment_graphs[segment_id])
945+
fragment_graph = deepcopy(fragment_graphs[fragment_name])
945946

946947
# Compute metric
947948
pair_to_length[pair_id] = self.compute_added_length(

0 commit comments

Comments
 (0)