
Commit 97dfa85

Implement subtype prediction WIP
1 parent 478bc0f commit 97dfa85

2 files changed: +139, -5 lines

scripts/measurements/merge_sgn_segmentation.py (10 additions, 5 deletions)

@@ -4,6 +4,8 @@
 
 import numpy as np
 import zarr
+import z5py
+
 from elf.evaluation.matching import label_overlap, intersection_over_union
 from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path
 from nifty.tools import blocking
@@ -13,7 +15,7 @@
 def merge_segmentations(seg_a, seg_b, ids_b, offset, output_path):
     assert seg_a.shape == seg_b.shape
 
-    output_file = zarr.open(output_path, mode="a")
+    output_file = z5py.File(output_path, mode="a")
     output = output_file.create_dataset("segmentation", shape=seg_a.shape, dtype=seg_a.dtype, chunks=seg_a.chunks)
     blocks = blocking([0, 0, 0], seg_a.shape, seg_a.chunks)
 
@@ -63,21 +65,24 @@ def merge_sgns(cochlea, name_a, name_b, overlap_threshold=0.25):
     cumulative_overlap = overlap[1:, :].sum(axis=0)
     all_ids_b = np.unique(seg_b)
     ids_b = all_ids_b[cumulative_overlap < overlap_threshold]
+    if 0 in ids_b:  # Zero is likely in the ids due to the logic.
+        ids_b = ids_b[1:]
+    assert 0 not in ids_b
     offset = seg_a.max()
 
     # Get the segmentations at full resolution to merge them.
-    seg_a = get_segmentation(cochlea, seg_name=name_a, seg_key="s2")
-    seg_b = get_segmentation(cochlea, seg_name=name_b, seg_key="s2")
+    seg_a = get_segmentation(cochlea, seg_name=name_a, seg_key="s0")
+    seg_b = get_segmentation(cochlea, seg_name=name_b, seg_key="s0")
 
     # Write out the merged segmentations.
     output_folder = f"./data/{cochlea}"
     os.makedirs(output_folder, exist_ok=True)
-    output_path = os.path.join(output_folder, "SGN_merged.zarr")
+    output_path = os.path.join(output_folder, "SGN_merged.n5")
     merge_segmentations(seg_a, seg_b, ids_b, offset, output_path)
 
 
 def main():
-    # merge_sgns(cochlea="M_AMD_N180_L", name_a="CR_SGN_v2", name_b="Ntng1_SGN_v2")
+    merge_sgns(cochlea="M_AMD_N180_L", name_a="CR_SGN_v2", name_b="Ntng1_SGN_v2")
     merge_sgns(cochlea="M_AMD_N180_R", name_a="CR_SGN_v2", name_b="Ntng1_SGN_v2")
 
 
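The ids_b filtering and the offset shift added here implement a paste-style merge: objects from the second segmentation that do not sufficiently overlap the first are copied in, with their ids shifted past the first segmentation's label range. Below is a toy sketch of that idea in plain numpy, leaving out the blockwise I/O the script performs; the arrays and the kept id are made up for illustration.

import numpy as np

# Toy label volumes, flattened to 1d; 0 is background.
seg_a = np.array([0, 1, 1, 2, 0, 0])
seg_b = np.array([0, 0, 3, 3, 4, 4])

ids_b = np.array([4])  # ids of seg_b objects that passed the overlap filter
offset = seg_a.max()   # shift the kept ids past the seg_a label range

merged = seg_a.copy()
mask = np.isin(seg_b, ids_b)
merged[mask] = seg_b[mask] + offset
print(merged)  # [0 1 1 2 6 6]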
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import os
2+
from glob import glob
3+
from pathlib import Path
4+
5+
import imageio.v3 as imageio
6+
import numpy as np
7+
import pandas as pd
8+
9+
from skimage.measure import regionprops
10+
from sklearn.linear_model import LogisticRegression
11+
from sklearn.model_selection import train_test_split
12+
from sklearn.metrics import accuracy_score
13+
14+
ROOT_AMD = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/Result_AMD"
15+
ROOT_EK = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/Result_EK"
16+
17+
18+
def load_annotations(pattern):
19+
paths = sorted(glob(pattern))
20+
annotations = [path[len(pattern):] for path in paths]
21+
channels = [annotation.split("_")[0] for annotation in annotations]
22+
return paths, channels
23+
24+
25+
# Get the features and per channel labels from this crop.
26+
def extract_crop_data(crop_table, crop_root):
27+
table = pd.read_csv(crop_table, sep="\t")
28+
prefix = Path(crop_table).stem
29+
30+
# Get the paths to all annotations.
31+
paths_amd, channels_amd = load_annotations(os.path.join(ROOT_AMD, f"positive-negative_{prefix}*"))
32+
paths_ek, channels_ek = load_annotations(os.path.join(ROOT_EK, f"positive-negative_{prefix}*"))
33+
channel_names = list(set(channels_amd))
34+
35+
# Load the segmentation.
36+
seg_path = os.path.join(crop_root, f"{prefix}_PV_SGN_v2.tif")
37+
seg = imageio.imread(seg_path)
38+
39+
# Load the features (= intensity and PV intensity ratios) for both channels.
40+
features = table[
41+
[f"marker_{channel_names[0]}", f"{channel_names[0]}_ratio_PV"] +
42+
[f"marker_{channel_names[1]}", f"{channel_names[1]}_ratio_PV"]
43+
].values
44+
45+
# Load the labels, derived from the annotations.
46+
labels = {channel: None for channel in channel_names}
47+
for channel, path in zip(channels_amd, paths_amd):
48+
data = imageio.imread(path)
49+
props = regionprops(seg, data)
50+
labeling = np.array([prop.max_intensity for prop in props], dtype="int32")
51+
if labels[channel] is None:
52+
labels[channel] = labeling
53+
else:
54+
# Combine labels so that we only keep the labels that agree, set others to zero
55+
# (in order to filter them out later).
56+
prev_labeling = labels[channel]
57+
disagreement = prev_labeling != labeling
58+
labeling[disagreement] = 0
59+
labels[channel] = labeling
60+
61+
return features, labels
62+
63+
64+
def process_cochlea(cochlea):
65+
# The root folders for tables and crop data for this cochlea.
66+
table_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/tables_{cochlea}"
67+
crop_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/{cochlea}"
68+
69+
# Getthe tables for all crops in this cochlea.
70+
tables = sorted(glob(os.path.join(table_root, "*.tsv")))
71+
72+
# Iterate over the crops, load the features and the labels per channel.
73+
features = []
74+
labels = {}
75+
for table in tables:
76+
crop_features, crop_labels = extract_crop_data(table, crop_root)
77+
features.append(crop_features)
78+
# Concatenate the labels per channel.
79+
for channel, labeling in crop_labels.items():
80+
if channel in labels:
81+
labels[channel] = np.concatenate([labels[channel], labeling], axis=0)
82+
else:
83+
labels[channel] = labeling
84+
features = np.concatenate(features, axis=0)
85+
86+
# Train and evaluate logistic regression per channel.
87+
start, stop = 0, 2
88+
for channel, labeling in labels.items():
89+
# Exclude labels with value zero.
90+
label_mask = labeling != 0
91+
# Get the features for this channel.
92+
this_features = features[:, start:stop][label_mask]
93+
this_labels = labeling[label_mask]
94+
95+
# Create a train and test split.
96+
train_features, test_features, train_labels, test_labels = train_test_split(
97+
this_features, this_labels, test_size=0.3
98+
)
99+
100+
# Train and evaluate the classifier.
101+
classifier = LogisticRegression(penalty="l2")
102+
classifier.fit(train_features, train_labels)
103+
104+
prediction = classifier.predict(test_features)
105+
accuracy = accuracy_score(test_labels, prediction)
106+
print("Channel:", channel)
107+
print("Accuracy:", accuracy)
108+
109+
start += 2
110+
stop += 2
111+
112+
# Note: we could do some other things here:
113+
# - Train a single classifier for subtype prediction (= 4 classes) using all channels.
114+
# - Use different classifier (e.g. RandomForest); however, accuracy from logistic regression looks fine.
115+
# - To better understand results we could also look at the confusion matrix.
116+
# - A better evaluation would be to train and test on separate blocks.
117+
118+
# The classifier can be saved and loaded with pickle, to apply it to all SGNs in the cochlea later.
119+
120+
121+
def main():
122+
# Process a cochlea by:
123+
# - Extracting the features (intensities and intensity ratios) and labels for each crop.
124+
# - Training a classifier based on the labels and evaluating it.
125+
process_cochlea("MLR99L")
126+
127+
128+
if __name__ == "__main__":
129+
main()
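As the note in the new script says, the trained classifier can be persisted with pickle and applied to all SGNs in the cochlea later. A minimal sketch, assuming a fitted classifier from process_cochlea and a hypothetical feature matrix all_sgn_features with the same column order as the training features:

import pickle

# Save the fitted classifier ("subtype_classifier.pkl" is a hypothetical file name).
with open("subtype_classifier.pkl", "wb") as f:
    pickle.dump(classifier, f)

# Load it again and predict subtypes for all SGNs
# (all_sgn_features is an assumed stand-in, not part of the script yet).
with open("subtype_classifier.pkl", "rb") as f:
    classifier = pickle.load(f)
subtype_predictions = classifier.predict(all_sgn_features)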

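The confusion matrix mentioned in the follow-up notes only needs the arrays that are already computed in the evaluation loop; a sketch with scikit-learn, using test_labels and prediction from that loop:

from sklearn.metrics import confusion_matrix

# Rows are true labels, columns are predicted labels.
print(confusion_matrix(test_labels, prediction))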