
Commit e5c1d1a

Figure for otof expression efficiency
1 parent 373c6bd commit e5c1d1a


3 files changed: +243, -6 lines changed


flamingo_tools/segmentation/unet_prediction.py

Lines changed: 2 additions & 0 deletions
@@ -571,6 +571,8 @@ def run_unet_prediction_slurm(
     """
     os.makedirs(output_folder, exist_ok=True)
     prediction_instances = int(prediction_instances)
+    if isinstance(scale, str):
+        scale = float(scale)
     slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

     if s3 is not None:
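
Note on the change above: when run_unet_prediction_slurm is launched through a SLURM batch script, parameters such as scale are typically forwarded as strings, hence the explicit cast. A minimal standalone sketch of the same normalization, using a hypothetical _normalize_scale helper that is not part of the repository:

def _normalize_scale(scale):
    # Accept None, a number, or a string such as "0.75" coming from a
    # CLI/SLURM wrapper, and return either None or a float.
    if scale is None:
        return None
    return float(scale)

print(_normalize_scale("0.75"))  # 0.75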

scripts/figures/plot_fig5.py

Lines changed: 0 additions & 5 deletions
@@ -9,7 +9,6 @@
 from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target

 from util import SYNAPSE_DIR_ROOT
-from plot_fig4 import get_chreef_data

 FILE_EXTENSION = "png"
 png_dpi = 300
@@ -143,10 +142,6 @@ def fig_05d(save_path, plot):

     # TODO would need the new intensity subtracted data here.
     # Reference: intensity distributions for ChReef
-    chreef_data = get_chreef_data()
-    for cochlea, tab in chreef_data.items():
-        plt.hist(tab["median"])
-        plt.show()


 def main():

scripts/figures/plot_fig6.py

Lines changed: 241 additions & 1 deletion
@@ -1,15 +1,85 @@
 import argparse
+import json
+import numpy as np
 import os
+import pickle

+import matplotlib.ticker as mticker
 import pandas as pd
 import matplotlib.pyplot as plt

+from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
 from util import literature_reference_values_gerbil, prism_cleanup_axes, prism_style, SYNAPSE_DIR_ROOT
+from util import frequency_mapping, custom_formatter_1, export_legend

 FILE_EXTENSION = "png"
 png_dpi = 300


+INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/LaVision-OTOF" # noqa
+
+# The cochleae for the OTOF analysis.
+COCHLEAE_DICT = {
+    "LaVision-OTOF23R": {"alias": "M02", "component": [4, 18, 7]},
+    "LaVision-OTOF25R": {"alias": "M03", "component": [1]},
+}
+
+COLOR_LEFT = "#8E00DB"
+COLOR_RIGHT = "#DB0063"
+MARKER_LEFT = "o"
+MARKER_RIGHT = "^"
+
+
+def get_otof_data():
+    s3 = create_s3_target()
+    source_name = "IHC_LOWRES-v3"
+
+    cache_path = "./otof_data.pkl"
+    cochleae = [key for key in COCHLEAE_DICT.keys()]
+
+    if os.path.exists(cache_path):
+        with open(cache_path, "rb") as f:
+            return pickle.load(f)
+
+    chreef_data = {}
+    for cochlea in cochleae:
+        print("Processing cochlea:", cochlea)
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        # Load the seg table and filter the compartments.
+        source = sources[source_name]["segmentation"]
+        rel_path = source["tableData"]["tsv"]["relativePath"]
+        table_content = s3.open(os.path.join(BUCKET_NAME, cochlea, rel_path, "default.tsv"), mode="rb")
+        table = pd.read_csv(table_content, sep="\t")
+        print(table.columns)
+
+        # May need to be adjusted for some cochleae.
+        component_labels = COCHLEAE_DICT[cochlea]["component"]
+        print(cochlea, component_labels)
+        table = table[table.component_labels.isin(component_labels)]
+        # The relevant values for analysis.
+        try:
+            values = table[["label_id", "length[µm]", "frequency[kHz]", "expression_classification"]]
+        except KeyError:
+            print("Could not find the values for", cochlea, "it will be skipped.")
+            continue
+
+        fname = f"{cochlea.replace('_', '-')}_rbOtof_IHC-LOWRES-v3_object-measures.tsv"
+        intensity_file = os.path.join(INTENSITY_ROOT, fname)
+        assert os.path.exists(intensity_file), intensity_file
+        intensity_table = pd.read_csv(intensity_file, sep="\t")
+        values = values.merge(intensity_table, on="label_id")
+
+        chreef_data[cochlea] = values
+
+    with open(cache_path, "wb") as f:
+        pickle.dump(chreef_data, f)
+    with open(cache_path, "rb") as f:
+        return pickle.load(f)
+
+
 # Load the synapse counts for all IHCs from the relevant tables.
 def _load_ribbon_synapse_counts():
     ihc_version = "ihc_counts_v6"
@@ -22,6 +92,171 @@ def _load_ribbon_synapse_counts():
     return syn_counts


+def plot_legend_fig06e(figure_dir):
+    color_dict = {
+        "O1": "#9C5027",
+        "O2": "#67279C",
+    }
+    save_path = os.path.join(figure_dir, f"fig_06e_legend.{FILE_EXTENSION}")
+    marker = ["o" for _ in color_dict]
+    label = list(color_dict.keys())
+    color = [color_dict[key] for key in color_dict.keys()]
+
+    f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0]
+    handles = [f(m, c) for (c, m) in zip(color, marker)]
+    legend = plt.legend(handles, label, loc=3, ncol=2, framealpha=1, frameon=False)
+    export_legend(legend, save_path)
+    legend.remove()
+    plt.close()
+
+
+def _get_trendline_dict(trend_dict,):
+    x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys()]
+    x_dict = {}
+    for num in range(len(x_sorted[0])):
+        x_dict[num] = {"pos": num, "values": []}
+
+    for s in x_sorted:
+        for num, pos in enumerate(s):
+            x_dict[num]["values"].append(pos)
+
+    y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys()]
+    y_dict = {}
+    for num in range(len(x_sorted[0])):
+        y_dict[num] = {"pos": num, "values": []}
+
+    for num in range(len(x_sorted[0])):
+        y_dict[num]["mean"] = np.mean([y[num] for y in y_sorted_all])
+        y_dict[num]["stdv"] = np.std([y[num] for y in y_sorted_all])
+    return x_dict, y_dict
+
+
+def _get_trendline_params(trend_dict):
+    x_dict, y_dict = _get_trendline_dict(trend_dict)
+
+    x_values = []
+    for key in x_dict.keys():
+        x_values.append(min(x_dict[key]["values"]))
+        x_values.append(max(x_dict[key]["values"]))
+
+    y_values_center = []
+    y_values_upper = []
+    y_values_lower = []
+    for key in y_dict.keys():
+        y_values_center.append(y_dict[key]["mean"])
+        y_values_center.append(y_dict[key]["mean"])
+
+        y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"])
+        y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"])
+
+        y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"])
+        y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"])
+
+    return x_values, y_values_center, y_values_upper, y_values_lower
+
+
+def fig_06e_octave(otof_data, save_path, plot=False, use_alias=True, trendline=False):
+    prism_style()
+    label_size = 20
+
+    result = {"cochlea": [], "octave_band": [], "value": []}
+    for name, values in otof_data.items():
+        if use_alias:
+            alias = COCHLEAE_DICT[name]["alias"]
+        else:
+            alias = name.replace("_", "").replace("0", "")
+
+        freq = values["frequency[kHz]"].values
+        marker_labels = values["expression_classification"].values
+        octave_binned = frequency_mapping(freq, marker_labels, animal="mouse", transduction_efficiency=True)
+
+        result["cochlea"].extend([alias] * len(octave_binned))
+        result["octave_band"].extend(octave_binned.axes[0].values.tolist())
+        result["value"].extend(octave_binned.values.tolist())
+
+    result = pd.DataFrame(result)
+    bin_labels = pd.unique(result["octave_band"])
+    band_to_x = {band: i for i, band in enumerate(bin_labels)}
+    result["x_pos"] = result["octave_band"].map(band_to_x)
+
+    colors = {
+        "M02": "#9C5027",
+        "M03": "#67279C",
+    }
+
+    fig, ax = plt.subplots(figsize=(8, 4))
+
+    offset = 0.08
+    trend_dict = {}
+    for num, (name, grp) in enumerate(result.groupby("cochlea")):
+        x_sorted = grp["x_pos"]
+        x_positions = [x - len(grp["x_pos"]) // 2 * offset + offset * num for x in grp["x_pos"]]
+        ax.scatter(x_positions, grp["value"], marker="o", label=name, s=80, alpha=1, color=colors[name])
+
+        # y_values.append(list(grp["value"]))
+
+        if trendline:
+            sorted_idx = np.argsort(x_positions)
+            x_sorted = np.array(x_positions)[sorted_idx]
+            y_sorted = np.array(grp["value"])[sorted_idx]
+            trend_dict[name] = {"x_sorted": x_sorted,
+                                "y_sorted": y_sorted,
+                                }
+    # central line
+    if trendline:
+        # mean, std = _get_trendline_params(y_values)
+        x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict)
+        trend_center, = ax.plot(
+            x_sorted,
+            y_sorted,
+            linestyle="dotted",
+            color="gray",
+            alpha=0.6,
+            linewidth=3,
+            zorder=2
+        )
+        # y_sorted_upper = np.array(mean) + np.array(std)
+        # y_sorted_lower = np.array(mean) - np.array(std)
+        # upper and lower standard deviation
+        trend_upper, = ax.plot(
+            x_sorted,
+            y_sorted_upper,
+            linestyle="solid",
+            color="gray",
+            alpha=0.08,
+            zorder=0
+        )
+        trend_lower, = ax.plot(
+            x_sorted,
+            y_sorted_lower,
+            linestyle="solid",
+            color="gray",
+            alpha=0.08,
+            zorder=0
+        )
+        plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper,
+                         color="gray", alpha=0.05, interpolate=True)
+
+    ax.set_xticks(range(len(bin_labels)))
+    ax.set_xticklabels(bin_labels)
+    ax.set_xlabel("Octave band [kHz]", fontsize=label_size)
+
+    ax.set_ylabel("Expression efficiency")
+    # plt.legend(title="Cochlea")
+    plt.tight_layout()
+    prism_cleanup_axes(ax)
+
+    if ".png" in save_path:
+        plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
+    else:
+        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+
+    if plot:
+        plt.show()
+    else:
+        plt.close()
+
+
 def fig_06b(save_path, plot=False):
     """Box plot showing the counts for SGN and IHC per gerbil cochlea in comparison to literature values.
     """
@@ -131,10 +366,15 @@ def fig_06d(save_path, plot=False):

 def main():
     parser = argparse.ArgumentParser(description="Generate plots for Fig 6 of the cochlea paper.")
-    parser.add_argument("figure_dir", type=str, help="Output directory for plots.", default="./panels")
+    parser.add_argument("-f", "--figure_dir", type=str, help="Output directory for plots.", default="./panels")
     args = parser.parse_args()
     plot = False

+    otof_data = get_otof_data()
+    plot_legend_fig06e(args.figure_dir)
+    fig_06e_octave(otof_data, save_path=os.path.join(args.figure_dir, f"fig_06e.{FILE_EXTENSION}"), plot=plot,
+                   trendline=False)
+
     fig_06b(save_path=os.path.join(args.figure_dir, f"fig_06b.{FILE_EXTENSION}"), plot=plot)
     fig_06d(save_path=os.path.join(args.figure_dir, f"fig_06d.{FILE_EXTENSION}"), plot=plot)
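
With figure_dir now an optional -f/--figure_dir argument that defaults to ./panels, the script can be invoked without arguments, for example (assuming S3 access and the intensity tables under INTENSITY_ROOT are available):

python scripts/figures/plot_fig6.py
python scripts/figures/plot_fig6.py -f ./custom_panels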
