Skip to content

Commit 907572b

Browse files
committed
Initial plots for SGN subtypes
1 parent ec49fc8 commit 907572b

File tree

1 file changed

+164
-53
lines changed

1 file changed

+164
-53
lines changed

scripts/measurements/sgn_subtypes.py

Lines changed: 164 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,82 @@
44
from glob import glob
55

66
import matplotlib.pyplot as plt
7+
import numpy as np
78
import pandas as pd
89
from skimage.filters import threshold_otsu
910

1011
from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path
1112
from flamingo_tools.measurements import compute_object_measures
1213

13-
sys.path.append("../figures")
14+
15+
# Define the animal specific octave bands.
16+
def _get_mapping(animal):
17+
if animal == "mouse":
18+
bin_edges = [0, 2, 4, 8, 16, 32, 64, np.inf]
19+
bin_labels = [
20+
"<2", "2–4", "4–8", "8–16", "16–32", "32–64", ">64"
21+
]
22+
elif animal == "gerbil":
23+
bin_edges = [0, 0.5, 1, 2, 4, 8, 16, 32, np.inf]
24+
bin_labels = [
25+
"<0.5", "0.5–1", "1–2", "2–4", "4–8", "8–16", "16–32", ">32"
26+
]
27+
else:
28+
raise ValueError
29+
assert len(bin_edges) == len(bin_labels) + 1
30+
return bin_edges, bin_labels
31+
32+
33+
def frequency_mapping(frequencies, values, animal="mouse", transduction_efficiency=False):
34+
# Get the mapping of frequencies to octave bands for the given species.
35+
bin_edges, bin_labels = _get_mapping(animal)
36+
37+
# Construct the data frame with octave bands.
38+
df = pd.DataFrame({"freq_khz": frequencies, "value": values})
39+
df["octave_band"] = pd.cut(
40+
df["freq_khz"], bins=bin_edges, labels=bin_labels, right=False
41+
)
42+
43+
if transduction_efficiency: # We compute the transduction efficiency per band.
44+
num_pos = df[df["value"] == 1].groupby("octave_band", observed=False).size()
45+
num_tot = df[df["value"].isin([1, 2])].groupby("octave_band", observed=False).size()
46+
value_by_band = (num_pos / num_tot).reindex(bin_labels)
47+
else: # Otherwise, aggregate the values over the octave band using the mean.
48+
value_by_band = (
49+
df.groupby("octave_band", observed=True)["value"]
50+
.sum()
51+
.reindex(bin_labels) # keep octave order even if a bin is empty
52+
)
53+
return value_by_band
54+
1455

1556
# Map from cochlea names to channels
1657
COCHLEAE_FOR_SUBTYPES = {
1758
"M_LR_000099_L": ["PV", "Calb1", "Lypd1"],
18-
"M_LR_000214_L": ["PV", "CR", "Calb1"],
19-
"M_AMD_N62_L": ["PV", "CR", "Calb1"],
20-
"M_AMD_N180_R": ["CR", "Ntng1", "CTBP2"],
21-
"M_AMD_N180_L": ["CR", "Ntng1", "Lypd1"],
59+
# "M_LR_000214_L": ["PV", "CR", "Calb1"],
60+
# "M_AMD_N62_L": ["PV", "CR", "Calb1"],
61+
# "M_AMD_N180_R": ["CR", "Ntng1", "CTBP2"],
62+
# "M_AMD_N180_L": ["CR", "Ntng1", "Lypd1"],
2263
"M_LR_000184_R": ["PV", "Prph"],
2364
"M_LR_000184_L": ["PV", "Prph"],
2465
# Mutant / some stuff is weird.
2566
# "M_AMD_Runx1_L": ["PV", "Lypd1", "Calb1"],
2667
# This one still has to be stitched:
2768
# "M_LR_000184_R": {"PV", "Prph"},
2869
}
70+
71+
COCHLEAE = {
72+
"M_LR_000184_L": {"seg_data": "SGN_v2", "subtype": ["Prph"], "output_seg": "SGN_v2b"},
73+
"M_LR_000184_R": {"seg_data": "SGN_v2", "subtype": ["Prph"], "output_seg": "SGN_v2b"},
74+
"M_LR_000099_L": {"seg_data": "PV_SGN_v2", "subtype": ["Calb1", "Lypd1"]},
75+
# "M_LR_000214_L": {"seg_data": "PV_SGN_v2", "subtype": ["Calb1"]},
76+
}
77+
78+
2979
REGULAR_COCHLEAE = [
30-
"M_LR_000099_L", "M_LR_000214_L", "M_AMD_N62_L", "M_LR_000184_R", "M_LR_000184_L"
80+
"M_LR_000099_L", "M_LR_000184_R", "M_LR_000184_L"
3181
]
82+
# "M_LR_000214_L", "M_AMD_N62_L",
3283

3384
# For custom thresholds.
3485
THRESHOLDS = {
@@ -118,32 +169,25 @@ def check_processing_status():
118169
if channels_missing:
119170
print("Missing the expected channels:", channels_missing)
120171

121-
if "SGN_v2" in sources:
172+
if "SGN_v2b" in sources:
173+
print("SGN segmentation is present with name SGN_v2b")
174+
table_folder = "tables/SGN_v2b"
175+
elif "SGN_v2" in sources:
122176
print("SGN segmentation is present with name SGN_v2")
123-
seg_name = "SGN-v2"
124177
table_folder = "tables/SGN_v2"
125178
elif "PV_SGN_v2" in sources:
126179
print("SGN segmentation is present with name PV_SGN_v2")
127-
seg_name = "PV-SGN-v2"
128180
table_folder = "tables/PV_SGN_v2"
129181
elif "CR_SGN_v2" in sources:
130182
print("SGN segmentation is present with name CR_SGN_v2")
131-
seg_name = "CR-SGN-v2"
132183
table_folder = "tables/CR_SGN_v2"
133184
else:
134185
print("SGN segmentation is MISSING")
135186
print()
136187
continue
137188

138189
# Check which tables we have.
139-
if cochlea == "M_AMD_N180_L": # we need all intensity measures here
140-
seg_names = ["CR-SGN-v2", "Ntng1-SGN-v2", "Lypd1-SGN-v2"]
141-
expected_tables = [f"{chan}_{sname}_object-measures.tsv" for chan in channels for sname in seg_names]
142-
elif cochlea == "M_AMD_N180_R":
143-
seg_names = ["CR-SGN-v2", "Ntng1-SGN-v2"]
144-
expected_tables = [f"{chan}_{sname}_object-measures.tsv" for chan in channels for sname in seg_names]
145-
else:
146-
expected_tables = [f"{chan}_{seg_name}_object-measures.tsv" for chan in channels]
190+
expected_tables = ["default.tsv"]
147191

148192
tables = s3.ls(os.path.join(BUCKET_NAME, cochlea, table_folder))
149193
tables = [os.path.basename(tab) for tab in tables]
@@ -226,8 +270,12 @@ def compile_data_for_subtype_analysis():
226270
seg_source = sources[seg_name]
227271
except KeyError as e:
228272
if seg_name == "PV_SGN_v2":
229-
seg_source = sources["SGN_v2"]
230-
seg_name = "SGN_v2"
273+
if "output_seg" in list(COCHLEAE[cochlea].keys()):
274+
seg_source = sources[COCHLEAE[cochlea]["output_seg"]]
275+
seg_name = COCHLEAE[cochlea]["output_seg"]
276+
else:
277+
seg_source = sources[COCHLEAE[cochlea]["output_seg"]]
278+
seg_name = COCHLEAE[cochlea]["output_seg"]
231279
else:
232280
raise e
233281
table_folder = os.path.join(
@@ -238,6 +286,7 @@ def compile_data_for_subtype_analysis():
238286

239287
# Get the SGNs in the main component
240288
table = table[table.component_labels == 1]
289+
print("Number of SGNs", len(table))
241290
valid_sgns = table.label_id
242291

243292
output_table = {"label_id": table.label_id.values, "frequency[kHz]": table["frequency[kHz]"]}
@@ -246,6 +295,14 @@ def compile_data_for_subtype_analysis():
246295
reference_intensity = None
247296
for channel in channels:
248297
# Load the intensity table, prefer local.
298+
table_folder = os.path.join(
299+
BUCKET_NAME, cochlea, seg_source["segmentation"]["tableData"]["tsv"]["relativePath"]
300+
)
301+
table_content = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
302+
table = pd.read_csv(table_content, sep="\t")
303+
table = table[table.component_labels == 1]
304+
305+
# local
249306
table_name = f"{channel}_{seg_name.replace('_', '-')}_object-measures.tsv"
250307
intensity_path = os.path.join("object_measurements", cochlea, table_name)
251308

@@ -257,7 +314,6 @@ def compile_data_for_subtype_analysis():
257314

258315
intensities = pd.read_csv(table_content, sep="\t")
259316
intensities = intensities[intensities.label_id.isin(valid_sgns)]
260-
261317
assert len(table) == len(intensities)
262318
assert (intensities.label_id.values == table.label_id.values).all()
263319

@@ -341,9 +397,8 @@ def _plot_2d(ratios, name, show_plots, classification=None, colors=None):
341397

342398

343399
def _plot_tonotopic_mapping(freq, classification, name, colors, show_plots):
344-
from util import frequency_mapping
345400

346-
frequency_mapped = frequency_mapping(freq, classification, categorical=True)
401+
frequency_mapped = frequency_mapping(freq, classification)
347402
result = next(iter(frequency_mapped.values()))
348403
bin_labels = pd.unique(result["octave_band"])
349404
band_to_x = {band: i for i, band in enumerate(bin_labels)}
@@ -379,30 +434,46 @@ def combined_analysis(results, show_plots):
379434
# Create the tonotopic mapping.
380435
#
381436
summary = {}
437+
colors = {}
382438
for cochlea, result in results.items():
383439
if cochlea == "M_LR_000214_L": # One of the signals cannot be analyzed.
384440
continue
385-
mapping = result["tonotopic_mapping"]
386-
summary[cochlea] = mapping
441+
classification = result["classification"]
442+
frequencies = result["frequencies"]
443+
# get categories
444+
cats = list(set([c[:c.find(" (")] for c in classification]))
445+
cats.sort()
446+
447+
dic = {}
448+
for c in cats:
449+
sub_freq = [frequencies[i] for i in range(len(classification))
450+
if classification[i][:classification[i].find(" (")] == c]
451+
mapping = frequency_mapping(sub_freq, [1 for _ in range(len(sub_freq))])
452+
mapping = mapping.astype('float32')
453+
dic[c] = mapping
454+
bin_labels = pd.unique(mapping.index)
455+
456+
if c not in colors:
457+
current_colors = list(colors.values())
458+
next_color = ALL_COLORS[len(current_colors)]
459+
colors[c] = next_color
387460

388-
colors = {}
461+
for bin in bin_labels:
462+
total = sum([dic[key][bin] for key in dic.keys()])
463+
for key in dic.keys():
464+
dic[key][bin] = float(dic[key][bin] / total)
465+
466+
summary[cochlea] = dic
389467

390468
fig, axes = plt.subplots(len(summary), sharey=True, figsize=(8, 8))
391-
for i, (cochlea, frequency_mapped) in enumerate(summary.items()):
469+
for i, (cochlea, dic) in enumerate(summary.items()):
470+
types = list(dic.keys())
392471
ax = axes[i]
393-
394-
result = next(iter(frequency_mapped.values()))
395-
bin_labels = pd.unique(result["octave_band"])
396-
band_to_x = {band: i for i, band in enumerate(bin_labels)}
397-
x_positions = result["octave_band"].map(band_to_x)
398-
399-
for cat, vals in frequency_mapped.items():
400-
values = vals.value
401-
cat = cat[:cat.find(" (")]
402-
if cat not in colors:
403-
current_colors = list(colors.values())
404-
next_color = ALL_COLORS[len(current_colors)]
405-
colors[cat] = next_color
472+
for cat in types:
473+
frequency_mapped = dic[cat]
474+
bin_labels = pd.unique(frequency_mapped.index)
475+
x_positions = [i for i in range(len(bin_labels))]
476+
values = frequency_mapped.values
406477
ax.scatter(x_positions, values, label=cat, color=colors[cat])
407478

408479
main_ticks = range(len(bin_labels))
@@ -462,19 +533,53 @@ def analyze_subtype_data_regular(show_plots=True):
462533
global PLOT_OUT, COLORS # noqa
463534
PLOT_OUT = "subtype_plots/regular_mice"
464535

536+
s3 = create_s3_target()
537+
465538
files = sorted(glob("./subtype_analysis/*.tsv"))
466539
results = {}
467540

468541
for ff in files:
469542
cochlea = os.path.basename(ff)[:-len("_subtype_analysis.tsv")]
470543
if cochlea not in REGULAR_COCHLEAE:
471544
continue
472-
print(cochlea)
473545
channels = COCHLEAE_FOR_SUBTYPES[cochlea]
474546

475547
reference_channel = "PV"
476548
assert channels[0] == reference_channel
477549

550+
seg_name = "PV_SGN_v2"
551+
552+
content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
553+
info = json.loads(content.read())
554+
sources = info["sources"]
555+
556+
# Load the segmentation table.
557+
try:
558+
seg_source = sources[seg_name]
559+
except KeyError as e:
560+
if seg_name == "PV_SGN_v2":
561+
if "output_seg" in list(COCHLEAE[cochlea].keys()):
562+
seg_source = sources[COCHLEAE[cochlea]["output_seg"]]
563+
seg_name = COCHLEAE[cochlea]["output_seg"]
564+
else:
565+
seg_source = sources[COCHLEAE[cochlea]["output_seg"]]
566+
seg_name = COCHLEAE[cochlea]["output_seg"]
567+
else:
568+
raise e
569+
table_folder = os.path.join(
570+
BUCKET_NAME, cochlea, seg_source["segmentation"]["tableData"]["tsv"]["relativePath"]
571+
)
572+
table_content = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
573+
table = pd.read_csv(table_content, sep="\t")
574+
table = table[table.component_labels == 1]
575+
576+
print(f"Length of table before filtering: {len(table)}")
577+
# filter subtype table
578+
for chan in channels[1:]:
579+
column = f"marker_{chan}"
580+
table = table.loc[table[column].isin([1,2])]
581+
print(f"Length of table after filtering channel {chan}: {len(table)}")
582+
478583
tab = pd.read_csv(ff, sep="\t")
479584

480585
# 1.) Plot simple intensity histograms, including otsu threshold.
@@ -488,12 +593,18 @@ def analyze_subtype_data_regular(show_plots=True):
488593
classification = []
489594
for chan in channels[1:]:
490595
column = f"{chan}_ratio_{reference_channel}"
491-
name = f"{cochlea}_{chan}_histogram_ratio_{reference_channel}"
492-
chan_classification = _plot_histogram(
493-
tab, column, name, class_names=[f"{chan}-", f"{chan}+"], show_plots=show_plots
494-
)
596+
# e.g. Calb1_ratio_PV
597+
column = f"marker_{chan}"
598+
subset = table.loc[table[column].isin([1, 2])]
599+
marker = list(subset[column])
600+
chan_classification = []
601+
for m in marker:
602+
if m == 1:
603+
chan_classification.append(f"{chan}+")
604+
elif m == 2:
605+
chan_classification.append(f"{chan}-")
495606
classification.append(chan_classification)
496-
ratios[f"{chan}_{reference_channel}"] = tab[column].values
607+
ratios[f"{chan}_{reference_channel}"] = table[column].values
497608

498609
# Unify the classification and assign colors
499610
assert len(classification) in (1, 2)
@@ -521,19 +632,18 @@ def analyze_subtype_data_regular(show_plots=True):
521632
COLORS[label] = ALL_COLORS[0]
522633

523634
# 3.) Plot tonotopic mapping.
524-
freq = tab["frequency[kHz]"].values
635+
freq = table["frequency[kHz]"].values
525636
assert len(freq) == len(classification)
526-
name = f"{cochlea}_tonotopic_mapping"
527-
tonotopic_mapping = _plot_tonotopic_mapping(
528-
freq, classification, name=name, colors=COLORS, show_plots=show_plots
529-
)
637+
#tonotopic_mapping = _plot_tonotopic_mapping(
638+
# freq, classification, name=name, colors=COLORS, show_plots=show_plots
639+
#)
530640

531641
# 4.) Plot 2D space of ratios.
532642
if show_2d:
533643
name = f"{cochlea}_2d"
534644
_plot_2d(ratios, name, show_plots, classification=classification, colors=COLORS)
535645

536-
results[cochlea] = {"classification": classification, "tonotopic_mapping": tonotopic_mapping}
646+
results[cochlea] = {"classification": classification, "frequencies": freq}
537647

538648
combined_analysis(results, show_plots=show_plots)
539649

@@ -594,7 +704,8 @@ def export_for_annotation():
594704
def main():
595705
# These scripts are for computing the intensity tables etc.
596706
missing_tables = check_processing_status()
597-
require_missing_tables(missing_tables)
707+
print("missing tables", missing_tables)
708+
# require_missing_tables(missing_tables)
598709
compile_data_for_subtype_analysis()
599710

600711
# This script is for exporting the tables for annotation in MoBIE.

0 commit comments

Comments
 (0)