Skip to content

Commit 50d1fb0

Browse files
committed
Unify figure style
1 parent d3b59cd commit 50d1fb0

File tree

4 files changed

+109
-122
lines changed

4 files changed

+109
-122
lines changed

scripts/figures/plot_fig2.py

Lines changed: 53 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -135,81 +135,73 @@ def fig_02c(save_path, plot=False, all_versions=False):
135135
ihc_annotator = [0.958, 0.956, 0.957]
136136
syn_unet = [0.931, 0.905, 0.918]
137137

138-
# This is the version with IHC v4b segmentation:
139-
# 4th version of the network with optimized segmentation params
140-
version_1 = [sgn_unet, sgn_annotator, ihc_v4b, ihc_annotator, syn_unet]
141-
settings_1 = ["automatic", "manual", "automatic", "manual", "automatic"]
138+
setting = ["SGN", "IHC", "Synapse"]
142139

143140
# This is the version with IHC v4c segmentation:
144141
# 4th version of the network with optimized segmentation params and split of falsely merged IHCs
145-
version_2 = [sgn_unet, sgn_annotator, ihc_v4c, ihc_annotator, syn_unet]
146-
settings_2 = ["automatic", "manual", "automatic", "manual", "automatic"]
142+
manual = [sgn_annotator, ihc_annotator]
143+
automatic = [sgn_unet, ihc_v4c, syn_unet]
147144

148-
# This is the version with IHC v4c + filter segmentation:
149-
# 4th version of the network with optimized segmentation params and split of falsely merged IHCs
150-
# + filtering out IHCs with zero mapped synapses.
151-
version_3 = [sgn_unet, sgn_annotator, ihc_v4c_filter, ihc_annotator, syn_unet]
152-
settings_3 = ["automatic", "manual", "automatic", "manual", "automatic"]
145+
precision_manual = [i[0] for i in manual]
146+
recall_manual = [i[1] for i in manual]
147+
f1score_manual = [i[2] for i in manual]
153148

154-
if all_versions:
155-
versions = [version_1, version_2, version_3]
156-
settings = [settings_1, settings_2, settings_3]
157-
save_suffix = ["_v4b", "_v4c", "_v4c_filter"]
158-
save_paths = [save_path.split(".")[0] + i + "." + save_path.split(".")[1] for i in save_suffix]
159-
else:
160-
versions = [version_2]
161-
settings = [settings_2]
162-
save_suffix = ["_v4c"]
163-
save_paths = [save_path.split(".")[0] + i + "." + save_path.split(".")[1] for i in save_suffix]
149+
precision_automatic = [i[0] for i in automatic]
150+
recall_automatic = [i[1] for i in automatic]
151+
f1score_automatic = [i[2] for i in automatic]
164152

165-
for version, setting, save_path in zip(versions, settings, save_paths):
166-
precision = [i[0] for i in version]
167-
recall = [i[1] for i in version]
168-
f1score = [i[2] for i in version]
153+
descr_y = 0.72
169154

170-
descr_y = 0.72
155+
# Convert setting labels to numerical x positions
156+
x = np.array([0.8, 1.2, 1.8, 2.2, 3])
157+
x_manual = np.array([0.8, 1.8])
158+
x_automatic = np.array([1.2, 2.2, 3])
159+
offset = 0.08 # horizontal shift for scatter separation
171160

172-
# Convert setting labels to numerical x positions
173-
x = np.array([0.8, 1.2, 1.8, 2.2, 3])
174-
offset = 0.08 # horizontal shift for scatter separation
161+
# Plot
162+
fig, ax = plt.subplots(figsize=(8, 5))
175163

176-
# Plot
177-
fig, ax = plt.subplots(figsize=(8, 5))
164+
main_label_size = 22
165+
sub_label_size = 16
166+
main_tick_size = 16
167+
legendsize = 18
178168

179-
main_label_size = 22
180-
sub_label_size = 16
181-
main_tick_size = 16
182-
legendsize = 18
169+
color_pm = "#3AA67E"
170+
color_pa = "#17E69A"
171+
color_rm = "#438CA7"
172+
color_ra = "#17AEE6"
173+
color_fm = "#694BA6"
174+
color_fa = "#6322E6"
183175

184-
plt.scatter(x - offset, precision, label="Precision", marker="o", s=80)
185-
plt.scatter(x, recall, label="Recall", marker="^", s=80)
186-
plt.scatter(x + offset, f1score, label="F1-score", marker="*", s=80)
176+
plt.scatter(x_manual - offset, precision_manual, label="Precision manual", color=color_pm, marker="o", s=80)
177+
plt.scatter(x_manual, recall_manual, label="Recall manual", color=color_rm, marker="o", s=80)
178+
plt.scatter(x_manual + offset, f1score_manual, label="F1-score manual", color=color_fm, marker="o", s=80)
187179

188-
plt.text(1, descr_y, "SGN", fontsize=main_label_size, horizontalalignment="center")
189-
plt.text(2, descr_y, "IHC", fontsize=main_label_size, horizontalalignment="center")
190-
plt.text(3, descr_y, "Synapse", fontsize=main_label_size, horizontalalignment="center")
180+
plt.scatter(x_automatic - offset, precision_automatic, label="Precision automatic", color=color_pa, marker="s", s=80)
181+
plt.scatter(x_automatic, recall_automatic, label="Recall automatic", color=color_ra, marker="s", s=80)
182+
plt.scatter(x_automatic + offset, f1score_automatic, label="F1-score automatic", color=color_fa, marker="s", s=80)
191183

192-
# Labels and formatting
193-
plt.xticks(x, setting, fontsize=sub_label_size)
194-
plt.yticks(fontsize=main_tick_size)
195-
plt.ylabel("Value", fontsize=main_label_size)
196-
plt.ylim(0.76, 1)
197-
plt.legend(loc="lower right",
198-
fontsize=legendsize)
199-
plt.grid(axis="y", linestyle="--", alpha=0.5)
184+
# Labels and formatting
185+
plt.xticks([1,2,3], setting, fontsize=main_label_size)
186+
plt.yticks(fontsize=main_tick_size)
187+
plt.ylabel("Value", fontsize=main_label_size)
188+
plt.ylim(0.76, 1)
189+
plt.legend(loc="lower right",
190+
fontsize=legendsize)
191+
plt.grid(axis="y", linestyle="--", alpha=0.5)
200192

201-
plt.tight_layout()
202-
prism_cleanup_axes(ax)
193+
plt.tight_layout()
194+
prism_cleanup_axes(ax)
203195

204-
if ".png" in save_path:
205-
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
206-
else:
207-
plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
196+
if ".png" in save_path:
197+
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
198+
else:
199+
plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
208200

209-
if plot:
210-
plt.show()
211-
else:
212-
plt.close()
201+
if plot:
202+
plt.show()
203+
else:
204+
plt.close()
213205

214206

215207
# Load the synapse counts for all IHCs from the relevant tables.
@@ -423,8 +415,8 @@ def main():
423415
os.makedirs(args.figure_dir, exist_ok=True)
424416

425417
# Panes A and B: Qualitative comparison of visualization results.
426-
fig_02a_sgn(save_dir=args.figure_dir, plot=args.plot)
427-
fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot)
418+
# fig_02a_sgn(save_dir=args.figure_dir, plot=args.plot)
419+
# fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot)
428420

429421
# Panel C: Evaluation of the segmentation results:
430422
fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False)

scripts/figures/plot_fig3.py

Lines changed: 34 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
from matplotlib import cm, colors
1414

1515
from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
16-
from util import sliding_runlength_sum, frequency_mapping, prism_style, prism_cleanup_axes, SYNAPSE_DIR_ROOT
16+
from util import sliding_runlength_sum, frequency_mapping, SYNAPSE_DIR_ROOT
1717

18-
# INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3"
19-
INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/frequency_mapping/M_LR_000227_R/scale3"
20-
FILE_EXTENSION = "png"
18+
INPUT_ROOT = "/home/martin/Documents/lightsheet-cochlea/M_LR_000227_R"
2119

2220
TYPE_TO_CHANNEL = {
2321
"Type-Ia": "CR",
@@ -93,7 +91,7 @@ def get_tonotopic_data():
9391
return pickle.load(f)
9492

9593

96-
def _plot_colormap(vol, title, plot, save_path):
94+
def _plot_colormap(vol, title, plot, save_path, cmap="viridis"):
9795
# before creating the figure:
9896
matplotlib.rcParams.update({
9997
"font.size": 14, # base font size
@@ -110,10 +108,16 @@ def _plot_colormap(vol, title, plot, save_path):
110108

111109
freq_min = np.min(np.nonzero(vol))
112110
freq_max = vol.max()
113-
norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True)
114-
cmap = plt.get_cmap("viridis")
111+
# norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True)
112+
norm = colors.LogNorm(vmin=freq_min, vmax=freq_max, clip=True)
113+
tick_values = np.array([10, 20, 40, 80])
114+
115+
cmap = plt.get_cmap(cmap)
115116

116-
cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal")
117+
cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal",
118+
ticks=tick_values)
119+
cb.ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
120+
cb.ax.xaxis.set_minor_locator(matplotlib.ticker.NullLocator())
117121
cb.set_label("Frequency [kHz]")
118122
plt.title(title)
119123
plt.tight_layout()
@@ -127,19 +131,29 @@ def _plot_colormap(vol, title, plot, save_path):
127131
plt.close()
128132

129133

130-
def fig_03a(save_path, plot, plot_napari):
134+
def fig_03a(save_path, plot, plot_napari, cmap="viridis"):
131135
path_ihc = os.path.join(INPUT_ROOT, "frequencies_IHC_v4c.tif")
132136
path_sgn = os.path.join(INPUT_ROOT, "frequencies_SGN_v2.tif")
133137
sgn = imageio.imread(path_sgn)
134138
ihc = imageio.imread(path_ihc)
135-
_plot_colormap(sgn, title="Tonotopic Mapping", plot=plot, save_path=save_path)
139+
_plot_colormap(sgn, title="Tonotopic Mapping", plot=plot, save_path=save_path, cmap=cmap)
136140

137141
# Show the image in napari for rendering.
138142
if plot_napari:
139143
import napari
144+
from napari.utils import Colormap
145+
# cmap = plt.get_cmap(cmap)
146+
mpl_cmap = plt.get_cmap(cmap)
147+
148+
# Sample it into an array of RGBA values
149+
colors = mpl_cmap(np.linspace(0, 1, 256))
150+
151+
# Wrap into napari Colormap
152+
napari_cmap = Colormap(colors, name=f"{cmap}_custom")
153+
140154
v = napari.Viewer()
141-
v.add_image(ihc, colormap="viridis")
142-
v.add_image(sgn, colormap="viridis")
155+
v.add_image(ihc, colormap=napari_cmap)
156+
v.add_image(sgn, colormap=napari_cmap)
143157
napari.run()
144158

145159

@@ -180,8 +194,7 @@ def fig_03c_rl(save_path, plot=False):
180194
plt.close()
181195

182196

183-
def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendlines=False):
184-
prism_style()
197+
def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True):
185198
ihc_version = "ihc_counts_v4c"
186199
tables = glob(os.path.join(SYNAPSE_DIR_ROOT, ihc_version, "ihc_count_M_LR*.tsv"))
187200
assert len(tables) == 4, len(tables)
@@ -207,55 +220,16 @@ def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendl
207220
result["x_pos"] = result["octave_band"].map(band_to_x)
208221

209222
fig, ax = plt.subplots(figsize=(8, 4))
210-
trend_dict = {}
211223
for name, grp in result.groupby("cochlea"):
212224
ax.scatter(grp["x_pos"], grp["value"], label=name, s=60, alpha=0.8)
213225

214-
if trendlines:
215-
x_positions = grp["x_pos"]
216-
sorted_idx = np.argsort(x_positions)
217-
x_sorted = np.array(x_positions)[sorted_idx]
218-
y_sorted = np.array(grp["value"])[sorted_idx]
219-
trend_dict[name] = {"x_sorted": x_sorted,
220-
"y_sorted": y_sorted,
221-
}
222-
223-
if trendlines:
224-
def get_trendline_values(trend_dict):
225-
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys()][0]
226-
y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys()]
227-
y_sorted = []
228-
for num in range(len(x_sorted)):
229-
y_sorted.append(np.mean([y[num] for y in y_sorted_all]))
230-
return x_sorted, y_sorted
231-
232-
# Trendline left
233-
x_sorted, y_sorted = get_trendline_values(trend_dict)
234-
235-
trend, = ax.plot(
236-
x_sorted,
237-
y_sorted,
238-
linestyle="dotted",
239-
color="grey",
240-
alpha=0.7
241-
)
242-
243-
# trendline_legend = ax.legend(handles=[trend], loc='lower center')
244-
# trendline_legend = ax.legend(
245-
# handles=[trend],
246-
# labels=["Trendline"],
247-
# loc="upper left"
248-
# )
249-
# # Add the legend manually to the Axes.
250-
# ax.add_artist(trendline_legend)
251-
252226
ax.set_xticks(range(len(bin_labels)))
253227
ax.set_xticklabels(bin_labels)
254228
ax.set_xlabel("Octave band (kHz)")
255229

256-
ax.set_ylabel("Average Ribbon Synapse Count per IHC", fontsize=10)
230+
ax.set_ylabel("Average Ribbon Synapse Count per IHC")
231+
ax.set_title("Ribbon synapse count per octave band")
257232
plt.legend(title="Cochlea")
258-
prism_cleanup_axes(ax)
259233

260234
if ".png" in save_path:
261235
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
@@ -345,15 +319,16 @@ def main():
345319
tonotopic_data = get_tonotopic_data()
346320

347321
# Panel A: Tonotopic mapping of SGNs and IHCs (rendering in napari + heatmap)
348-
# fig_03a(save_path=os.path.join(args.figure_dir, f"fig_03a_cmap.{FILE_EXTENSION}"),
349-
# plot=args.plot, plot_napari=True)
322+
cmap = "plasma"
323+
fig_03a(save_path=os.path.join(args.figure_dir, f"fig_03a_cmap_{cmap}.{FILE_EXTENSION}"),
324+
plot=args.plot, plot_napari=True, cmap=cmap)
350325

351326
# Panel C: Spatial distribution of synapses across the cochlea.
352327
# We have two options: running sum over the runlength or per octave band
353328
# fig_03c_rl(save_path=os.path.join(args.figure_dir, f"fig_03c_runlength.{FILE_EXTENSION}"), plot=args.plot)
354329
fig_03c_octave(tonotopic_data=tonotopic_data,
355-
save_path=os.path.join(args.figure_dir, f"fig_03c_octave.{FILE_EXTENSION}"),
356-
plot=args.plot, trendlines=True)
330+
save_path=os.path.join(args.figure_dir, f"fig_03c_octave.{FILE_EXTENSION}"),
331+
plot=args.plot)
357332

358333
# Panel D: Spatial distribution of SGN sub-types.
359334
# fig_03d_fraction(save_path=os.path.join(args.figure_dir, f"fig_03d_fraction.{FILE_EXTENSION}"), plot=args.plot)

scripts/figures/plot_fig4.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,21 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
309309
# Assign a color to each cochlea (ignoring side)
310310
cochleas = sorted({name_lr[:-1] for name_lr in result["cochlea"].unique()})
311311
colors = plt.cm.tab10.colors # pick a colormap
312+
colors = ["#DB5748", "#DB4B6F", "#DB49C6", "#B748DB", "#8748DB"]
313+
312314
color_map = {cochlea: colors[i % len(colors)] for i, cochlea in enumerate(cochleas)}
315+
316+
313317
if len(cochleas) == 1:
314318
color_map = {"L": colors[0], "R": colors[1]}
315319

316320
# Track which cochlea names we have already added to the legend
317321
legend_added = set()
318322

323+
offset = 0.02
319324
trend_dict = {}
320325

321-
for name_lr, grp in result.groupby("cochlea"):
326+
for num, (name_lr, grp) in enumerate(result.groupby("cochlea")):
322327
name, side = name_lr[:-1], name_lr[-1]
323328
if len(cochleas) == 1:
324329
label_name = name_lr
@@ -327,7 +332,7 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
327332
label_name = name
328333
color = color_map[name]
329334

330-
x_positions = grp["x_pos"] + offset_map[side]
335+
x_positions = grp["x_pos"] + offset_map[side] - len(cochleas) / 2 * offset + offset * num
331336
ax.scatter(
332337
x_positions,
333338
grp["value"],

scripts/figures/util.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ def ax_prism_boxplot(ax, data, positions=None, color="tab:blue"):
3131
return bp
3232

3333

34+
prism_palette = [
35+
"#4E79A7", # blue
36+
"#F28E2B", # orange
37+
"#E15759", # red
38+
"#76B7B2", # teal
39+
"#59A14F", # green
40+
"#EDC948", # yellow
41+
"#B07AA1", # purple
42+
"#FF9DA7", # pink
43+
"#9C755F", # brown
44+
"#BAB0AC" # gray
45+
]
46+
47+
3448
def prism_style():
3549
plt.style.use("default") # reset any active styles
3650
plt.rcParams.update({
@@ -44,6 +58,7 @@ def prism_style():
4458
"axes.linewidth": 1.2,
4559
"axes.labelsize": 14,
4660
"axes.labelweight": "bold",
61+
"axes.prop_cycle": plt.cycler("color", prism_palette),
4762

4863
# Ticks
4964
"xtick.direction": "out",

0 commit comments

Comments
 (0)