
Commit 0deb8c9

Refactor data challenge (#144)

Authored by anna-grim

* bug: swc reader
* refactor: updated results file

Co-authored-by: anna-grim <[email protected]>

1 parent: 0a104b2

1 file changed: +22 −7 lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 22 additions & 7 deletions
@@ -129,6 +129,8 @@ def __init__(
         row_names = list(self.graphs.keys())
         row_names.sort()
         col_names = [
+            "SWC Name",
+            "SWC Run Length",
             "# Splits",
             "# Merges",
            "Split Rate",
@@ -139,9 +141,9 @@ def __init__(
             "Edge Accuracy",
             "ERL",
             "Normalized ERL",
-            "GT Run Length",
         ]
         self.metrics = pd.DataFrame(index=row_names, columns=col_names)
+        self.metrics["SWC Name"] = self.metrics.index
 
     # --- Load Data ---
     def load_groundtruth(self, swc_pointer, label_mask):
@@ -293,13 +295,13 @@ def run(self):
         path = f"{self.output_dir}/{prefix}results.csv"
         if self.fragment_graphs is None:
             self.metrics = self.metrics.drop("# Merges", axis=1)
-        self.metrics.to_csv(path, index=True)
+        self.metrics.to_csv(path, index=False)
 
         # Report results
         path = os.path.join(self.output_dir, f"{prefix}results-overview.txt")
         util.update_txt(path, "Average Results...")
         for column_name in self.metrics.columns:
-            if column_name != "GT Run Length":
+            if column_name != "SWC Run Length" and column_name != "SWC Name":
                 avg = self.compute_weighted_avg(column_name)
                 util.update_txt(path, f" {column_name}: {avg:.4f}")
@@ -345,7 +347,7 @@ def detect_splits(self):
             self.metrics.at[key, "Split Rate"] = rl / max(n_splits, 1)
             self.metrics.loc[key, "% Split Edges"] = round(p_split, 2)
             self.metrics.at[key, "% Omit Edges"] = round(p_omit, 2)
-            self.metrics.loc[key, "GT Run Length"] = round(gt_rl, 2)
+            self.metrics.loc[key, "SWC Run Length"] = round(gt_rl, 2)
 
             if self.verbose:
                 pbar.update(1)
@@ -532,9 +534,10 @@ def process_merge_sites(self):
             idx_mask = self.merge_sites["GroundTruth_ID"] == key
             n_merges = int(idx_mask.sum())
             rl = np.sum(self.graphs[key].run_lengths())
+            merge_rate = rl / n_merges if n_merges > 0 else np.nan
 
             self.metrics.loc[key, "# Merges"] = n_merges
-            self.metrics.loc[key, "Merge Rate"] = rl / max(n_merges, 1)
+            self.metrics.loc[key, "Merge Rate"] = merge_rate
 
         # Save results
         path = os.path.join(self.output_dir, "merge_sites.csv")
@@ -626,8 +629,20 @@ def compute_erl(self):
             self.metrics.loc[key, "Normalized ERL"] = n_erl
 
     def compute_weighted_avg(self, column_name):
-        wgt = self.metrics["GT Run Length"]
-        return (self.metrics[column_name] * wgt).sum() / wgt.sum()
+        # Extract values
+        values = self.metrics[column_name]
+        weights = self.metrics["SWC Run Length"]
+
+        # Ignore NaNs
+        mask = values.notna() & weights.notna()
+        values = values[mask]
+        weights = weights[mask]
+
+        # Compute weighted mean
+        if weights.sum() == 0:
+            return float("nan")
+        else:
+            return (values * weights).sum() / weights.sum()
 
     # -- Helpers --
     def branch_search(self, graph, kdtree, root, radius=100):
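
Note on the compute_weighted_avg refactor: the averages written to results-overview.txt are now weighted by the new "SWC Run Length" column, rows where either the metric or the run length is NaN are dropped before averaging, and if no usable weight remains the average is reported as NaN rather than dividing by zero. A minimal standalone sketch of that behavior on a made-up metrics table (the weighted_avg helper and the values below are illustrative, not part of the repository):

import numpy as np
import pandas as pd

def weighted_avg(values, weights):
    # NaN-aware weighted mean, mirroring the refactored compute_weighted_avg
    mask = values.notna() & weights.notna()
    values, weights = values[mask], weights[mask]
    if weights.sum() == 0:
        return float("nan")
    return (values * weights).sum() / weights.sum()

# Hypothetical table: the middle row has no "Merge Rate" (NaN), as can now
# happen when process_merge_sites finds zero merges for a ground-truth graph.
metrics = pd.DataFrame({
    "SWC Run Length": [100.0, 50.0, 25.0],
    "Merge Rate": [0.2, np.nan, 0.4],
})
print(weighted_avg(metrics["Merge Rate"], metrics["SWC Run Length"]))
# 0.24 = (0.2 * 100 + 0.4 * 25) / (100 + 25); the NaN row is ignored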
