11import os
2+ import sys
23from glob import glob
34from pathlib import Path
45
56import imageio .v3 as imageio
67import numpy as np
78import pandas as pd
9+ import pickle
810
911from skimage .measure import regionprops
1012from sklearn .linear_model import LogisticRegression
13+ from flamingo_tools .s3_utils import get_s3_path
1114from sklearn .model_selection import train_test_split
1215from sklearn .metrics import accuracy_score
1316
1417ROOT_AMD = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/Result_AMD"
1518ROOT_EK = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/Result_EK"
1619
20+ COCHLEA_DICT = {
21+ "MLR99L" : {"cochlea_long" : "M_LR_000099_L" , "seg_name" : "PV_SGN_v2" }
22+ }
23+
1724
1825def load_annotations (pattern ):
1926 paths = sorted (glob (pattern ))
@@ -23,17 +30,35 @@ def load_annotations(pattern):
2330
2431
2532# Get the features and per channel labels from this crop.
26- def extract_crop_data (crop_table , crop_root ):
33+ def extract_crop_data (cochlea , crop_table , crop_root ):
2734 table = pd .read_csv (crop_table , sep = "\t " )
2835 prefix = Path (crop_table ).stem
2936
3037 # Get the paths to all annotations.
3138 paths_amd , channels_amd = load_annotations (os .path .join (ROOT_AMD , f"positive-negative_{ prefix } *" ))
3239 paths_ek , channels_ek = load_annotations (os .path .join (ROOT_EK , f"positive-negative_{ prefix } *" ))
3340 channel_names = list (set (channels_amd ))
41+ channel_names .sort ()
42+
43+ cochlea_long = COCHLEA_DICT [cochlea ]["cochlea_long" ]
44+ seg_name = COCHLEA_DICT [cochlea ]["seg_name" ]
45+
46+ for channel in channel_names :
47+ s3_path = f"{ cochlea_long } /tables/{ seg_name } /{ channel } _{ "-" .join (seg_name .split ("_" ))} _object-measures.tsv"
48+ tsv_path , fs = get_s3_path (s3_path )
49+ with fs .open (tsv_path , 'r' ) as f :
50+ table_measure = pd .read_csv (f , sep = "\t " )
51+
52+ table = table .merge (
53+ table_measure [["label_id" , "median" ]],
54+ on = "label_id" ,
55+ how = "left"
56+ )
57+ # Rename the merged column
58+ table .rename (columns = {"median" : f"intensity_{ channel } " }, inplace = True )
3459
3560 # Load the segmentation.
36- seg_path = os .path .join (crop_root , f"{ prefix } _PV_SGN_v2 .tif" )
61+ seg_path = os .path .join (crop_root , f"{ prefix } _ { seg_name } .tif" )
3762 seg = imageio .imread (seg_path )
3863
3964 # Load the features (= intensity and PV intensity ratios) for both channels.
@@ -44,6 +69,9 @@ def extract_crop_data(crop_table, crop_root):
4469
4570 # Load the labels, derived from the annotations.
4671 labels = {channel : None for channel in channel_names }
72+ # total_channels = channels_amd + channels_ek
73+ # total_paths = paths_amd + paths_ek
74+
4775 for channel , path in zip (channels_amd , paths_amd ):
4876 data = imageio .imread (path )
4977 props = regionprops (seg , data )
@@ -65,6 +93,7 @@ def process_cochlea(cochlea):
6593 # The root folders for tables and crop data for this cochlea.
6694 table_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/tables_{ cochlea } "
6795 crop_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/{ cochlea } "
96+ model_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/classifier/{ cochlea } "
6897
6998 # Getthe tables for all crops in this cochlea.
7099 tables = sorted (glob (os .path .join (table_root , "*.tsv" )))
@@ -73,7 +102,7 @@ def process_cochlea(cochlea):
73102 features = []
74103 labels = {}
75104 for table in tables :
76- crop_features , crop_labels = extract_crop_data (table , crop_root )
105+ crop_features , crop_labels = extract_crop_data (cochlea , table , crop_root )
77106 features .append (crop_features )
78107 # Concatenate the labels per channel.
79108 for channel , labeling in crop_labels .items ():
@@ -92,6 +121,10 @@ def process_cochlea(cochlea):
92121 this_features = features [:, start :stop ][label_mask ]
93122 this_labels = labeling [label_mask ]
94123
124+ labels = list (set (this_labels ))
125+ for l in labels :
126+ print (f"label { l } occurences: { list (this_labels ).count (l )} " )
127+
95128 # Create a train and test split.
96129 train_features , test_features , train_labels , test_labels = train_test_split (
97130 this_features , this_labels , test_size = 0.3
@@ -102,9 +135,21 @@ def process_cochlea(cochlea):
102135 classifier .fit (train_features , train_labels )
103136
104137 prediction = classifier .predict (test_features )
138+ model_path = os .path .join (model_root , f"logistic_{ channel } .pkl" )
139+ with open (model_path , "wb" ) as f :
140+ pickle .dump (classifier , f )
141+
105142 accuracy = accuracy_score (test_labels , prediction )
106143 print ("Channel:" , channel )
107144 print ("Accuracy:" , accuracy )
145+ for index in [1 , 2 ]:
146+ label_list = []
147+ pred_list = []
148+ for l , p in zip (test_labels , prediction ):
149+ if l == index :
150+ label_list .append (l )
151+ pred_list .append (p )
152+ print (f"Accuracy label { index } : { accuracy_score (np .array (label_list ), np .array (pred_list ))} " )
108153
109154 start += 2
110155 stop += 2
@@ -118,11 +163,59 @@ def process_cochlea(cochlea):
118163 # The classifier can be saved and loaded with pickle, to apply it to all SGNs in the cochlea later.
119164
120165
166+ def apply_model (cochlea ):
167+ model_root = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes/classifier/{ cochlea } "
168+ models = [entry .path for entry in os .scandir (model_root ) if ".pkl" in entry .name ]
169+
170+ cochlea_long = COCHLEA_DICT [cochlea ]["cochlea_long" ]
171+ seg_name = COCHLEA_DICT [cochlea ]["seg_name" ]
172+
173+ s3_path = os .path .join (f"{ cochlea_long } " , "tables" , f"{ seg_name } " , "default.tsv" )
174+ tsv_path , fs = get_s3_path (s3_path )
175+ with fs .open (tsv_path , 'r' ) as f :
176+ table = pd .read_csv (f , sep = "\t " )
177+
178+ for model_path in models :
179+ channel = os .path .basename (model_path ).split (".pkl" )[0 ].split ("_" )[1 ]
180+
181+ s3_path = f"{ cochlea_long } /tables/{ seg_name } /{ channel } _{ "-" .join (seg_name .split ("_" ))} _object-measures.tsv"
182+ tsv_path , fs = get_s3_path (s3_path )
183+ with fs .open (tsv_path , 'r' ) as f :
184+ table_measure = pd .read_csv (f , sep = "\t " )
185+
186+ table = table .merge (
187+ table_measure [["label_id" , "median" ]],
188+ on = "label_id" ,
189+ how = "left"
190+ )
191+ table .rename (columns = {"median" : f"intensity_{ channel } " }, inplace = True )
192+
193+ subset = table .loc [table [f"marker_{ channel } " ].isin ([1 , 2 ])]
194+ features = subset [
195+ [f"marker_{ channel } " , f"{ channel } _ratio_PV" ]
196+ ].values
197+
198+ with open (model_path , "rb" ) as f :
199+ classifier = pickle .load (f )
200+
201+ prediction = classifier .predict (features )
202+ # switch prediction to be consistent with markers: 1 - positive, 2 - negative
203+ prediction = [2 if x == 1 else 1 for x in prediction ]
204+
205+ table .loc [:, f"classifier_{ channel } " ] = 0
206+ table .loc [subset .index , f"classifier_{ channel } " ] = prediction
207+
208+ out_path = os .path .join (model_root , cochlea + ".tsv" )
209+ table .to_csv (out_path , sep = "\t " , index = False )
210+
211+
121212def main ():
122213 # Process a cochlea by:
123214 # - Extracting the features (intensities and intensity ratios) and labels for each crop.
124215 # - Training a classifier based on the labels and evaluating it.
125- process_cochlea ("MLR99L" )
216+ cochlea = "MLR99L"
217+ process_cochlea (cochlea )
218+ apply_model (cochlea )
126219
127220
128221if __name__ == "__main__" :
0 commit comments