@@ -87,6 +87,49 @@ def get_sgn_counts(cochlea):
8787 return frequencies , values
8888
8989
90+ def average_densities (curves , * , nbins = 512 , weights = None , renormalize = True ):
91+ if len (curves ) == 0 :
92+ raise ValueError ("curves must be non-empty" )
93+
94+ # Global domain across all inputs
95+ xmin = min (g [0 ][0 ] for g in curves )
96+ xmax = max (g [0 ][- 1 ] for g in curves )
97+ if not np .isfinite ([xmin , xmax ]).all () or xmax <= xmin :
98+ raise ValueError ("Invalid global domain from inputs." )
99+
100+ grid_common = np .linspace (xmin , xmax , nbins )
101+ interp_dens = []
102+
103+ for grid , dens in curves :
104+ grid = np .asarray (grid , float )
105+ dens = np .asarray (dens , float )
106+ # Interpolate onto common grid; outside each curve's support -> 0
107+ interp = np .interp (grid_common , grid , dens , left = 0.0 , right = 0.0 )
108+ # Clip tiny negatives that may appear from numeric noise
109+ interp_dens .append (np .clip (interp , 0.0 , np .inf ))
110+
111+ M = np .vstack (interp_dens ) # shape: (n_curves, nbins)
112+
113+ if weights is None :
114+ w = np .ones (M .shape [0 ], float )
115+ else :
116+ w = np .asarray (weights , float )
117+ if w .shape [0 ] != M .shape [0 ]:
118+ raise ValueError ("weights must have same length as number of curves" )
119+ if np .any (w < 0 ):
120+ raise ValueError ("weights must be non-negative" )
121+ w = w / w .sum ()
122+
123+ mean_density = (w [:, None ] * M ).sum (axis = 0 )
124+
125+ if renormalize :
126+ area = np .trapz (mean_density , grid_common )
127+ if area > 0 :
128+ mean_density /= area
129+
130+ return grid_common , mean_density
131+
132+
90133def check_implementation ():
91134 cochlea = "G_EK_000049_L"
92135 analyze_cochlea (cochlea , plot = True )
@@ -129,8 +172,60 @@ def compare_cochleae(cochleae, animal, plot_density=True, plot_tonotopy=True):
129172 plt .show ()
130173
131174
132- # TODO: implement the same for mouse cochleae (healthy vs. opto treatment)
133- # also show this in tonotopic mapping
175+ def compare_cochlea_groups (cochlea_groups , animal , plot_density = True , plot_tonotopy = True ):
176+
177+ if plot_density :
178+ fix , axes = plt .subplots (2 , sharey = True , sharex = True )
179+ for name , cochleae in cochlea_groups .items ():
180+ group_values = []
181+ for cochlea in cochleae :
182+ grid , density = analyze_cochlea (cochlea , plot = False )
183+ axes [0 ].plot (grid , density , lw = 1 , label = cochlea , alpha = 0.8 )
184+ group_values .append ((grid , density ))
185+ group_grid , group_density = average_densities (group_values , nbins = len (grid ), renormalize = False )
186+ axes [1 ].plot (group_grid , group_density , label = name , lw = 2 )
187+
188+ for ax in axes :
189+ ax .set_xlabel ("Length [µm]" )
190+ ax .set_ylabel ("Density [SGN/µm]" )
191+ ax .legend ()
192+ plt .tight_layout ()
193+ plt .show ()
194+
195+ if plot_tonotopy :
196+ from util import frequency_mapping
197+
198+ fig , axes = plt .subplots (2 , sharey = True )
199+ for name , cochleae in cochlea_groups .items ():
200+ grp_values = []
201+ for cochlea in cochleae :
202+ frequencies , values = get_sgn_counts (cochlea )
203+ sgns_per_band = frequency_mapping (
204+ frequencies , values , animal = animal , aggregation = "sum"
205+ )
206+ bin_labels = sgns_per_band .index
207+ binned_counts = sgns_per_band .values
208+
209+ band_to_x = {band : i for i , band in enumerate (bin_labels )}
210+ x_positions = bin_labels .map (band_to_x )
211+ axes [0 ].scatter (x_positions , binned_counts , marker = "o" , label = cochlea , s = 80 )
212+
213+ grp_values .append (binned_counts )
214+
215+ grp_values = np .array (grp_values )
216+ grp_mean = grp_values .mean (axis = 0 )
217+ grp_std = grp_values .std (axis = 0 )
218+
219+ axes [1 ].plot (x_positions , grp_mean , lw = 2 , label = name )
220+ axes [1 ].fill_between (x_positions , grp_mean - grp_std , grp_mean + grp_std , alpha = 0.3 )
221+
222+ for ax in axes :
223+ ax .set_xticks (range (len (bin_labels )))
224+ ax .set_xticklabels (bin_labels )
225+ ax .set_xlabel ("Octave band [kHz]" )
226+ ax .set_ylabel ("SGN Count" )
227+ ax .legend ()
228+ plt .show ()
134229
135230
136231# The visualization has to be improved to make plots understandable.
@@ -144,15 +239,27 @@ def main():
144239 # Comparison for Mouse.
145240 # NOTE: There is some problem with M_LR_000143_L and "M_LR_000153_L"
146241 # I have removed the corresponding pairs for now, but we should investigate and add back.
147- cochleae = [
148- # Healthy reference cochleae.
242+
243+ # Healthy reference cochleae.
244+ reference_cochleae = [
149245 "M_LR_000226_L" , "M_LR_000226_R" , "M_LR_000227_L" , "M_LR_000227_R" ,
150- # Right un-injected cochleae.
246+ ]
247+ # Right un-injected cochleae.
248+ uninjected_cochleae = [
151249 "M_LR_000144_R" , "M_LR_000145_R" , "M_LR_000155_R" , "M_LR_000189_R" ,
152- # Left injected cochleae.
250+ ]
251+ # Left injected cochleae.
252+ injected_cochleae = [
153253 "M_LR_000144_L" , "M_LR_000145_L" , "M_LR_000155_L" , "M_LR_000189_L" ,
154254 ]
155- compare_cochleae (cochleae , animal = "mouse" )
255+ compare_cochlea_groups (
256+ {
257+ "reference" : reference_cochleae ,
258+ "uninjected" : uninjected_cochleae ,
259+ "injected" : injected_cochleae ,
260+ },
261+ animal = "mouse" , plot_tonotopy = True , plot_density = True ,
262+ )
156263
157264
158265if __name__ == "__main__" :
0 commit comments