Commit 4372cb2

Model application
1 parent 97dfa85 commit 4372cb2

File tree

1 file changed: +97 −4 lines


scripts/measurements/subtype_prediction.py

Lines changed: 97 additions & 4 deletions
@@ -1,19 +1,26 @@
 import os
+import sys
 from glob import glob
 from pathlib import Path
 
 import imageio.v3 as imageio
 import numpy as np
 import pandas as pd
+import pickle
 
 from skimage.measure import regionprops
 from sklearn.linear_model import LogisticRegression
+from flamingo_tools.s3_utils import get_s3_path
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score
 
 ROOT_AMD = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/Result_AMD"
 ROOT_EK = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/Result_EK"
 
+COCHLEA_DICT = {
+    "MLR99L": {"cochlea_long": "M_LR_000099_L", "seg_name": "PV_SGN_v2"}
+}
+
 
 def load_annotations(pattern):
     paths = sorted(glob(pattern))
@@ -23,17 +30,35 @@ def load_annotations(pattern):
 
 
 # Get the features and per channel labels from this crop.
-def extract_crop_data(crop_table, crop_root):
+def extract_crop_data(cochlea, crop_table, crop_root):
     table = pd.read_csv(crop_table, sep="\t")
     prefix = Path(crop_table).stem
 
     # Get the paths to all annotations.
     paths_amd, channels_amd = load_annotations(os.path.join(ROOT_AMD, f"positive-negative_{prefix}*"))
     paths_ek, channels_ek = load_annotations(os.path.join(ROOT_EK, f"positive-negative_{prefix}*"))
     channel_names = list(set(channels_amd))
+    channel_names.sort()
+
+    cochlea_long = COCHLEA_DICT[cochlea]["cochlea_long"]
+    seg_name = COCHLEA_DICT[cochlea]["seg_name"]
+
+    for channel in channel_names:
+        s3_path = f"{cochlea_long}/tables/{seg_name}/{channel}_{'-'.join(seg_name.split('_'))}_object-measures.tsv"
+        tsv_path, fs = get_s3_path(s3_path)
+        with fs.open(tsv_path, 'r') as f:
+            table_measure = pd.read_csv(f, sep="\t")
+
+        table = table.merge(
+            table_measure[["label_id", "median"]],
+            on="label_id",
+            how="left"
+        )
+        # Rename the merged column.
+        table.rename(columns={"median": f"intensity_{channel}"}, inplace=True)
 
     # Load the segmentation.
-    seg_path = os.path.join(crop_root, f"{prefix}_PV_SGN_v2.tif")
+    seg_path = os.path.join(crop_root, f"{prefix}_{seg_name}.tif")
     seg = imageio.imread(seg_path)
 
     # Load the features (= intensity and PV intensity ratios) for both channels.
@@ -44,6 +69,9 @@ def extract_crop_data(crop_table, crop_root):
 
     # Load the labels, derived from the annotations.
     labels = {channel: None for channel in channel_names}
+    # total_channels = channels_amd + channels_ek
+    # total_paths = paths_amd + paths_ek
+
     for channel, path in zip(channels_amd, paths_amd):
         data = imageio.imread(path)
         props = regionprops(seg, data)
@@ -65,6 +93,7 @@ def process_cochlea(cochlea):
     # The root folders for tables and crop data for this cochlea.
     table_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/tables_{cochlea}"
     crop_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/{cochlea}"
+    model_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/classifier/{cochlea}"
 
     # Get the tables for all crops in this cochlea.
     tables = sorted(glob(os.path.join(table_root, "*.tsv")))
@@ -73,7 +102,7 @@ def process_cochlea(cochlea):
     features = []
     labels = {}
     for table in tables:
-        crop_features, crop_labels = extract_crop_data(table, crop_root)
+        crop_features, crop_labels = extract_crop_data(cochlea, table, crop_root)
         features.append(crop_features)
         # Concatenate the labels per channel.
         for channel, labeling in crop_labels.items():
@@ -92,6 +121,10 @@ def process_cochlea(cochlea):
         this_features = features[:, start:stop][label_mask]
         this_labels = labeling[label_mask]
 
+        labels = list(set(this_labels))
+        for l in labels:
+            print(f"label {l} occurrences: {list(this_labels).count(l)}")
+
         # Create a train and test split.
         train_features, test_features, train_labels, test_labels = train_test_split(
             this_features, this_labels, test_size=0.3
@@ -102,9 +135,21 @@ def process_cochlea(cochlea):
         classifier.fit(train_features, train_labels)
 
         prediction = classifier.predict(test_features)
+        model_path = os.path.join(model_root, f"logistic_{channel}.pkl")
+        with open(model_path, "wb") as f:
+            pickle.dump(classifier, f)
+
         accuracy = accuracy_score(test_labels, prediction)
         print("Channel:", channel)
         print("Accuracy:", accuracy)
+        for index in [1, 2]:
+            label_list = []
+            pred_list = []
+            for l, p in zip(test_labels, prediction):
+                if l == index:
+                    label_list.append(l)
+                    pred_list.append(p)
+            print(f"Accuracy label {index}: {accuracy_score(np.array(label_list), np.array(pred_list))}")
 
         start += 2
         stop += 2
@@ -118,11 +163,59 @@ def process_cochlea(cochlea):
     # The classifier can be saved and loaded with pickle, to apply it to all SGNs in the cochlea later.
 
 
+def apply_model(cochlea):
+    model_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/classifier/{cochlea}"
+    models = [entry.path for entry in os.scandir(model_root) if ".pkl" in entry.name]
+
+    cochlea_long = COCHLEA_DICT[cochlea]["cochlea_long"]
+    seg_name = COCHLEA_DICT[cochlea]["seg_name"]
+
+    s3_path = os.path.join(f"{cochlea_long}", "tables", f"{seg_name}", "default.tsv")
+    tsv_path, fs = get_s3_path(s3_path)
+    with fs.open(tsv_path, 'r') as f:
+        table = pd.read_csv(f, sep="\t")
+
+    for model_path in models:
+        channel = os.path.basename(model_path).split(".pkl")[0].split("_")[1]
+
+        s3_path = f"{cochlea_long}/tables/{seg_name}/{channel}_{'-'.join(seg_name.split('_'))}_object-measures.tsv"
+        tsv_path, fs = get_s3_path(s3_path)
+        with fs.open(tsv_path, 'r') as f:
+            table_measure = pd.read_csv(f, sep="\t")
+
+        table = table.merge(
+            table_measure[["label_id", "median"]],
+            on="label_id",
+            how="left"
+        )
+        table.rename(columns={"median": f"intensity_{channel}"}, inplace=True)
+
+        subset = table.loc[table[f"marker_{channel}"].isin([1, 2])]
+        features = subset[
+            [f"marker_{channel}", f"{channel}_ratio_PV"]
+        ].values
+
+        with open(model_path, "rb") as f:
+            classifier = pickle.load(f)
+
+        prediction = classifier.predict(features)
+        # Switch the prediction to be consistent with the markers: 1 - positive, 2 - negative.
+        prediction = [2 if x == 1 else 1 for x in prediction]
+
+        table.loc[:, f"classifier_{channel}"] = 0
+        table.loc[subset.index, f"classifier_{channel}"] = prediction
+
+    out_path = os.path.join(model_root, cochlea + ".tsv")
+    table.to_csv(out_path, sep="\t", index=False)
+
+
 def main():
     # Process a cochlea by:
     # - Extracting the features (intensities and intensity ratios) and labels for each crop.
     # - Training a classifier based on the labels and evaluating it.
-    process_cochlea("MLR99L")
+    cochlea = "MLR99L"
+    process_cochlea(cochlea)
+    apply_model(cochlea)
 
 
 if __name__ == "__main__":
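
For reference, a minimal usage sketch (not part of the commit) of how one of the pickled per-channel classifiers written by process_cochlea could be reloaded on its own and applied to a small feature matrix. The cochlea path, the channel name, and the feature values below are placeholders; the two-features-per-channel layout follows the training code in the diff above.

# Minimal sketch, assuming a logistic_<channel>.pkl file already exists in model_root.
import os
import pickle

import numpy as np

model_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/classifier/MLR99L"
channel = "CR"  # placeholder channel name

# Reload the classifier saved by process_cochlea.
with open(os.path.join(model_root, f"logistic_{channel}.pkl"), "rb") as f:
    classifier = pickle.load(f)

# Each row is one SGN with the two per-channel feature values used at training time.
features = np.array([
    [1200.0, 0.8],  # placeholder values
    [300.0, 0.2],
])
print(classifier.predict(features))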

0 commit comments
