Skip to content

Commit d1de337

Browse files
Update SGN density plots
1 parent c642d14 commit d1de337

File tree

1 file changed

+114
-7
lines changed

1 file changed

+114
-7
lines changed

scripts/measurements/density_analysis.py

Lines changed: 114 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,49 @@ def get_sgn_counts(cochlea):
8787
return frequencies, values
8888

8989

90+
def average_densities(curves, *, nbins=512, weights=None, renormalize=True):
91+
if len(curves) == 0:
92+
raise ValueError("curves must be non-empty")
93+
94+
# Global domain across all inputs
95+
xmin = min(g[0][0] for g in curves)
96+
xmax = max(g[0][-1] for g in curves)
97+
if not np.isfinite([xmin, xmax]).all() or xmax <= xmin:
98+
raise ValueError("Invalid global domain from inputs.")
99+
100+
grid_common = np.linspace(xmin, xmax, nbins)
101+
interp_dens = []
102+
103+
for grid, dens in curves:
104+
grid = np.asarray(grid, float)
105+
dens = np.asarray(dens, float)
106+
# Interpolate onto common grid; outside each curve's support -> 0
107+
interp = np.interp(grid_common, grid, dens, left=0.0, right=0.0)
108+
# Clip tiny negatives that may appear from numeric noise
109+
interp_dens.append(np.clip(interp, 0.0, np.inf))
110+
111+
M = np.vstack(interp_dens) # shape: (n_curves, nbins)
112+
113+
if weights is None:
114+
w = np.ones(M.shape[0], float)
115+
else:
116+
w = np.asarray(weights, float)
117+
if w.shape[0] != M.shape[0]:
118+
raise ValueError("weights must have same length as number of curves")
119+
if np.any(w < 0):
120+
raise ValueError("weights must be non-negative")
121+
w = w / w.sum()
122+
123+
mean_density = (w[:, None] * M).sum(axis=0)
124+
125+
if renormalize:
126+
area = np.trapz(mean_density, grid_common)
127+
if area > 0:
128+
mean_density /= area
129+
130+
return grid_common, mean_density
131+
132+
90133
def check_implementation():
91134
cochlea = "G_EK_000049_L"
92135
analyze_cochlea(cochlea, plot=True)
@@ -129,8 +172,60 @@ def compare_cochleae(cochleae, animal, plot_density=True, plot_tonotopy=True):
129172
plt.show()
130173

131174

132-
# TODO: implement the same for mouse cochleae (healthy vs. opto treatment)
133-
# also show this in tonotopic mapping
175+
def compare_cochlea_groups(cochlea_groups, animal, plot_density=True, plot_tonotopy=True):
176+
177+
if plot_density:
178+
fix, axes = plt.subplots(2, sharey=True, sharex=True)
179+
for name, cochleae in cochlea_groups.items():
180+
group_values = []
181+
for cochlea in cochleae:
182+
grid, density = analyze_cochlea(cochlea, plot=False)
183+
axes[0].plot(grid, density, lw=1, label=cochlea, alpha=0.8)
184+
group_values.append((grid, density))
185+
group_grid, group_density = average_densities(group_values, nbins=len(grid), renormalize=False)
186+
axes[1].plot(group_grid, group_density, label=name, lw=2)
187+
188+
for ax in axes:
189+
ax.set_xlabel("Length [µm]")
190+
ax.set_ylabel("Density [SGN/µm]")
191+
ax.legend()
192+
plt.tight_layout()
193+
plt.show()
194+
195+
if plot_tonotopy:
196+
from util import frequency_mapping
197+
198+
fig, axes = plt.subplots(2, sharey=True)
199+
for name, cochleae in cochlea_groups.items():
200+
grp_values = []
201+
for cochlea in cochleae:
202+
frequencies, values = get_sgn_counts(cochlea)
203+
sgns_per_band = frequency_mapping(
204+
frequencies, values, animal=animal, aggregation="sum"
205+
)
206+
bin_labels = sgns_per_band.index
207+
binned_counts = sgns_per_band.values
208+
209+
band_to_x = {band: i for i, band in enumerate(bin_labels)}
210+
x_positions = bin_labels.map(band_to_x)
211+
axes[0].scatter(x_positions, binned_counts, marker="o", label=cochlea, s=80)
212+
213+
grp_values.append(binned_counts)
214+
215+
grp_values = np.array(grp_values)
216+
grp_mean = grp_values.mean(axis=0)
217+
grp_std = grp_values.std(axis=0)
218+
219+
axes[1].plot(x_positions, grp_mean, lw=2, label=name)
220+
axes[1].fill_between(x_positions, grp_mean - grp_std, grp_mean + grp_std, alpha=0.3)
221+
222+
for ax in axes:
223+
ax.set_xticks(range(len(bin_labels)))
224+
ax.set_xticklabels(bin_labels)
225+
ax.set_xlabel("Octave band [kHz]")
226+
ax.set_ylabel("SGN Count")
227+
ax.legend()
228+
plt.show()
134229

135230

136231
# The visualization has to be improved to make plots understandable.
@@ -144,15 +239,27 @@ def main():
144239
# Comparison for Mouse.
145240
# NOTE: There is some problem with M_LR_000143_L and "M_LR_000153_L"
146241
# I have removed the corresponding pairs for now, but we should investigate and add back.
147-
cochleae = [
148-
# Healthy reference cochleae.
242+
243+
# Healthy reference cochleae.
244+
reference_cochleae = [
149245
"M_LR_000226_L", "M_LR_000226_R", "M_LR_000227_L", "M_LR_000227_R",
150-
# Right un-injected cochleae.
246+
]
247+
# Right un-injected cochleae.
248+
uninjected_cochleae = [
151249
"M_LR_000144_R", "M_LR_000145_R", "M_LR_000155_R", "M_LR_000189_R",
152-
# Left injected cochleae.
250+
]
251+
# Left injected cochleae.
252+
injected_cochleae = [
153253
"M_LR_000144_L", "M_LR_000145_L", "M_LR_000155_L", "M_LR_000189_L",
154254
]
155-
compare_cochleae(cochleae, animal="mouse")
255+
compare_cochlea_groups(
256+
{
257+
"reference": reference_cochleae,
258+
"uninjected": uninjected_cochleae,
259+
"injected": injected_cochleae,
260+
},
261+
animal="mouse", plot_tonotopy=True, plot_density=True,
262+
)
156263

157264

158265
if __name__ == "__main__":

0 commit comments

Comments
 (0)