33from glob import glob
44from subprocess import run
55
6+ import matplotlib .pyplot as plt
67import pandas as pd
8+ from skimage .filters import threshold_otsu
79
810from flamingo_tools .s3_utils import BUCKET_NAME , create_s3_target , get_s3_path
911from flamingo_tools .measurements import compute_object_measures
4547 },
4648}
4749
50+ PLOT_OUT = "./subtype_plots"
51+
4852
4953def check_processing_status ():
5054 s3 = create_s3_target ()
@@ -206,15 +210,83 @@ def compile_data_for_subtype_analysis():
206210 output_table .to_csv (out_path , sep = "\t " , index = False )
207211
208212
def _plot_histogram(table, column, name, show_plots, subtype=None):
    """Plot a histogram of one table column and mark its Otsu threshold.

    Args:
        table: Table supporting ``table[column].values`` (the caller passes a
            pandas DataFrame read from a tsv file).
        column: Name of the column whose values are plotted.
        name: Plot title; also the file stem used when the plot is saved.
        show_plots: If True the plot is displayed, otherwise it is written to
            ``PLOT_OUT/<name>.png``.
        subtype: Optional label assigned to every value that is not below the
            threshold.

    Returns:
        None if ``subtype`` is None. Otherwise a list with one entry per value
        in the column: ``subtype`` for values >= threshold, None for the rest.
    """
    data = table[column].values
    threshold = threshold_otsu(data)

    fig, ax = plt.subplots(1)
    ax.hist(data, bins=24)
    ax.axvline(x=threshold, color="red", linestyle="--")
    ax.set_title(f"{name}\n threshold: {threshold}")

    if show_plots:
        plt.show()
    else:
        os.makedirs(PLOT_OUT, exist_ok=True)
        fig.savefig(f"{PLOT_OUT}/{name}.png")
    # pyplot keeps a reference to every figure until it is closed. This
    # function is called once per channel and cochlea, so close the figure
    # explicitly to avoid accumulating open figures.
    plt.close(fig)

    if subtype is None:
        return None
    return [None if datum < threshold else subtype for datum in data]
231+
232+
def _plot_2d(ratios, name, show_plots, classification=None):
    """Scatter-plot two ratio measurements against each other.

    Args:
        ratios: Dict with exactly two entries, mapping a ratio name to a
            sequence of values. The first key is plotted on the x-axis and the
            second on the y-axis.
        name: Plot title; also the file stem used when the plot is saved.
        show_plots: If True the plot is displayed, otherwise it is written to
            ``PLOT_OUT/<name>.png``.
        classification: Optional list of per-point label sequences. Entries
            that are None are ignored. Labels of the remaining sequences are
            merged point-wise: None is dropped and two labels are joined as
            ``"a-b"``. Points are colored by their merged label; points
            without any label are drawn as empty black circles.
    """
    assert len(ratios) == 2
    k1, k2 = ratios.keys()
    x_values, y_values = ratios[k1], ratios[k2]

    def _combine(a, b):
        # Merge two optional labels; None means "no label".
        if a is None:
            return b
        if b is None:
            return a
        return f"{a}-{b}"

    def _select(values, keep):
        # Keep the values whose merged label satisfies the predicate.
        return [val for val, lbl in zip(values, labels) if keep(lbl)]

    # Drop missing per-channel classifications. If none is left there is
    # nothing to color by, so the plot falls back to a plain scatter.
    per_channel = [cls for cls in (classification or []) if cls is not None]

    fig, ax = plt.subplots(1)
    if not per_channel:
        ax.scatter(x_values, y_values)
    else:
        labels = list(per_channel[0])
        for cls in per_channel[1:]:
            labels = [_combine(a, b) for a, b in zip(labels, cls)]

        # Sort the labels so the label -> color assignment is reproducible
        # between runs, and wrap around the palette instead of failing when
        # there are more labels than colors.
        unique_labels = sorted({ll for ll in labels if ll is not None}, key=str)
        all_colors = ["red", "blue", "orange", "yellow"]
        colors = {ll: all_colors[i % len(all_colors)] for i, ll in enumerate(unique_labels)}

        for lbl in unique_labels:
            ax.scatter(
                _select(x_values, lambda ll, lbl=lbl: ll == lbl),
                _select(y_values, lambda ll, lbl=lbl: ll == lbl),
                c=colors[lbl], label=lbl
            )

        # Points that did not pass any threshold.
        ax.scatter(
            _select(x_values, lambda ll: ll is None),
            _select(y_values, lambda ll: ll is None),
            facecolors="none", edgecolors="black", label="None"
        )

        ax.legend()

    ax.set_xlabel(k1)
    ax.set_ylabel(k2)
    ax.set_title(name)

    if show_plots:
        plt.show()
    else:
        os.makedirs(PLOT_OUT, exist_ok=True)
        fig.savefig(f"{PLOT_OUT}/{name}.png")
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)
218290
219291
220292# TODO enable over-writing by manual thresholds
@@ -229,24 +301,29 @@ def analyze_subtype_data(show_plots=True):
229301 assert channels [0 ] == reference_channel
230302
231303 tab = pd .read_csv (ff , sep = "\t " )
232- breakpoint ()
233304
234305 # 1.) Plot simple intensity histograms, including otsu threshold.
235306 for chan in channels :
236307 column = f"{ chan } _median"
237- name = f"{ cochlea } _{ chan } _histogram.png "
308+ name = f"{ cochlea } _{ chan } _histogram"
238309 _plot_histogram (tab , column , name , show_plots )
239310
240311 # 2.) Plot ratio histograms, including otsu threshold.
241- ratios = {}
242312 # TODO ratio based classification and overlay in 2d plot?
313+ ratios = {}
314+ subtype_classification = []
243315 for chan in channels [1 :]:
244- column = f"{ chan } _median_ratio_{ reference_channel } "
245- name = f"{ cochlea } _{ chan } _histogram_ratio_{ reference_channel } .png"
246- _plot_histogram (tab , column , name , show_plots )
316+ column = f"{ chan } _ratio_{ reference_channel } "
317+ name = f"{ cochlea } _{ chan } _histogram_ratio_{ reference_channel } "
318+ classification = _plot_histogram (
319+ tab , column , name , subtype = CHANNEL_TO_TYPE .get (chan , None ), show_plots = show_plots
320+ )
321+ subtype_classification .append (classification )
247322 ratios [f"{ chan } _{ reference_channel } " ] = tab [column ].values
248323
249324 # 3.) Plot 2D space of ratios.
325+ name = f"{ cochlea } _2d"
326+ _plot_2d (ratios , name , show_plots , classification = subtype_classification )
250327
251328
252329# General notes:
@@ -256,12 +333,12 @@ def analyze_subtype_data(show_plots=True):
256333# M_AMD_N62_L: PV signal and segmentation look good.
257334# M_AMD_N180_R: Need SGN segmentation based on CR.
def main():
    """Run the subtype analysis with ``show_plots=False``.

    The earlier pipeline steps are kept below as comments so they can be
    re-enabled when needed.
    """
    # missing_tables = check_processing_status()
    # require_missing_tables(missing_tables)

    # compile_data_for_subtype_analysis()

    analyze_subtype_data(show_plots=False)
265342
266343
267344if __name__ == "__main__" :
0 commit comments