Skip to content

Commit 71b46ce

Browse files
Update subtype analysis
1 parent 06ac936 commit 71b46ce

File tree

1 file changed

+90
-13
lines changed

1 file changed

+90
-13
lines changed

scripts/measurements/sgn_subtypes.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from glob import glob
44
from subprocess import run
55

6+
import matplotlib.pyplot as plt
67
import pandas as pd
8+
from skimage.filters import threshold_otsu
79

810
from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path
911
from flamingo_tools.measurements import compute_object_measures
@@ -45,6 +47,8 @@
4547
},
4648
}
4749

50+
PLOT_OUT = "./subtype_plots"
51+
4852

4953
def check_processing_status():
5054
s3 = create_s3_target()
@@ -206,15 +210,83 @@ def compile_data_for_subtype_analysis():
206210
output_table.to_csv(out_path, sep="\t", index=False)
207211

208212

209-
def _plot_histogram(table, column, name, show_plots):
213+
def _plot_histogram(table, column, name, show_plots, subtype=None):
210214
data = table[column].values
215+
threshold = threshold_otsu(data)
216+
217+
fig, ax = plt.subplots(1)
218+
ax.hist(data, bins=24)
219+
ax.axvline(x=threshold, color='red', linestyle='--')
220+
ax.set_title(f"{name}\n threshold: {threshold}")
221+
222+
if show_plots:
223+
plt.show()
224+
else:
225+
os.makedirs(PLOT_OUT, exist_ok=True)
226+
plt.savefig(f"{PLOT_OUT}/{name}.png")
227+
228+
if subtype is not None:
229+
subtype_classification = [None if datum < threshold else subtype for datum in data]
230+
return subtype_classification
231+
232+
233+
def _plot_2d(ratios, name, show_plots, classification=None):
234+
fig, ax = plt.subplots(1)
235+
assert len(ratios) == 2
236+
keys = list(ratios.keys())
237+
k1, k2 = keys
238+
239+
if classification is None:
240+
ax.scatter(ratios[k1, k2])
241+
242+
else:
243+
def _combine(a, b):
244+
if a is None and b is None:
245+
return None
246+
elif a is None and b is not None:
247+
return b
248+
elif a is not None and b is None:
249+
return a
250+
else:
251+
return f"{a}-{b}"
252+
253+
classification = [cls for cls in classification if cls is not None]
254+
labels = classification[0].copy()
255+
for cls in classification[1:]:
256+
if cls is None:
257+
continue
258+
labels = [_combine(a, b) for a, b in zip(labels, cls)]
259+
260+
unique_labels = set(ll for ll in labels if ll is not None)
261+
all_colors = ["red", "blue", "orange", "yellow"]
262+
colors = {ll: color for ll, color in zip(unique_labels, all_colors[:len(unique_labels)])}
263+
264+
for lbl in unique_labels:
265+
mask = [ll == lbl for ll in labels]
266+
ax.scatter(
267+
[ratios[k1][i] for i in range(len(labels)) if mask[i]],
268+
[ratios[k2][i] for i in range(len(labels)) if mask[i]],
269+
c=colors[lbl], label=lbl
270+
)
271+
272+
mask_none = [ll is None for ll in labels]
273+
ax.scatter(
274+
[ratios[k1][i] for i in range(len(labels)) if mask_none[i]],
275+
[ratios[k2][i] for i in range(len(labels)) if mask_none[i]],
276+
facecolors="none", edgecolors="black", label="None"
277+
)
211278

212-
# TODO determine automatic threshold
279+
ax.legend()
280+
281+
ax.set_xlabel(k1)
282+
ax.set_ylabel(k2)
283+
ax.set_title(name)
213284

214285
if show_plots:
215-
pass
286+
plt.show()
216287
else:
217-
pass
288+
os.makedirs(PLOT_OUT, exist_ok=True)
289+
plt.savefig(f"{PLOT_OUT}/{name}.png")
218290

219291

220292
# TODO enable over-writing by manual thresholds
@@ -229,24 +301,29 @@ def analyze_subtype_data(show_plots=True):
229301
assert channels[0] == reference_channel
230302

231303
tab = pd.read_csv(ff, sep="\t")
232-
breakpoint()
233304

234305
# 1.) Plot simple intensity histograms, including otsu threshold.
235306
for chan in channels:
236307
column = f"{chan}_median"
237-
name = f"{cochlea}_{chan}_histogram.png"
308+
name = f"{cochlea}_{chan}_histogram"
238309
_plot_histogram(tab, column, name, show_plots)
239310

240311
# 2.) Plot ratio histograms, including otsu threshold.
241-
ratios = {}
242312
# TODO ratio based classification and overlay in 2d plot?
313+
ratios = {}
314+
subtype_classification = []
243315
for chan in channels[1:]:
244-
column = f"{chan}_median_ratio_{reference_channel}"
245-
name = f"{cochlea}_{chan}_histogram_ratio_{reference_channel}.png"
246-
_plot_histogram(tab, column, name, show_plots)
316+
column = f"{chan}_ratio_{reference_channel}"
317+
name = f"{cochlea}_{chan}_histogram_ratio_{reference_channel}"
318+
classification = _plot_histogram(
319+
tab, column, name, subtype=CHANNEL_TO_TYPE.get(chan, None), show_plots=show_plots
320+
)
321+
subtype_classification.append(classification)
247322
ratios[f"{chan}_{reference_channel}"] = tab[column].values
248323

249324
# 3.) Plot 2D space of ratios.
325+
name = f"{cochlea}_2d"
326+
_plot_2d(ratios, name, show_plots, classification=subtype_classification)
250327

251328

252329
# General notes:
@@ -256,12 +333,12 @@ def analyze_subtype_data(show_plots=True):
256333
# M_AMD_N62_L: PV signal and segmentation look good.
257334
# M_AMD_N180_R: Need SGN segmentation based on CR.
258335
def main():
259-
missing_tables = check_processing_status()
260-
require_missing_tables(missing_tables)
336+
# missing_tables = check_processing_status()
337+
# require_missing_tables(missing_tables)
261338

262339
# compile_data_for_subtype_analysis()
263340

264-
# analyze_subtype_data()
341+
analyze_subtype_data(show_plots=False)
265342

266343

267344
if __name__ == "__main__":

0 commit comments

Comments
 (0)