@@ -121,7 +121,7 @@ def __init__(
121121 wfs = wfs_
122122
123123 # make histogram density
124- wfs_flat = wfs .swapaxes (1 , 2 ).reshape (wfs .shape [0 ], - 1 ) # num_spikes x times* num_channels
124+ wfs_flat = wfs .swapaxes (1 , 2 ).reshape (wfs .shape [0 ], - 1 ) # num_spikes x ( num_channels * timepoints)
125125 hists_per_timepoint = [np .histogram (one_timepoint , bins = bins )[0 ] for one_timepoint in wfs_flat .T ]
126126 hist2d = np .stack (hists_per_timepoint )
127127
@@ -157,6 +157,7 @@ def __init__(
157157 bin_min = bin_min ,
158158 bin_max = bin_max ,
159159 all_hist2d = all_hist2d ,
160+ sampling_frequency = sorting_analyzer .sampling_frequency ,
160161 templates_flat = templates_flat ,
161162 template_width = wfs .shape [1 ],
162163 )
@@ -173,37 +174,36 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
173174 backend_kwargs ["num_axes" ] = 1 if dp .same_axis else len (dp .unit_ids )
174175 self .figure , self .axes , self .ax = make_mpl_figure (** backend_kwargs )
175176
177+ freq_khz = dp .sampling_frequency / 1000 # samples / msec
176178 if dp .same_axis :
177- ax = self .ax
178179 hist2d = dp .all_hist2d
179- im = ax .imshow (
180+ x_max = len (hist2d ) / freq_khz # in milliseconds
181+ self .ax .imshow (
180182 hist2d .T ,
181183 interpolation = "nearest" ,
182184 origin = "lower" ,
183185 aspect = "auto" ,
184- extent = (0 , hist2d . shape [ 0 ] , dp .bin_min , dp .bin_max ),
186+ extent = (0 , x_max , dp .bin_min , dp .bin_max ),
185187 cmap = "hot" ,
186188 )
187189 else :
188- for unit_index , unit_id in enumerate ( dp .unit_ids ):
190+ for ax , unit_id in zip ( self . axes . flatten (), dp .unit_ids ):
189191 hist2d = dp .all_hist2d [unit_id ]
190- ax = self . axes . flatten ()[ unit_index ]
191- im = ax .imshow (
192+ x_max = len ( hist2d ) / freq_khz # in milliseconds
193+ ax .imshow (
192194 hist2d .T ,
193195 interpolation = "nearest" ,
194196 origin = "lower" ,
195197 aspect = "auto" ,
196- extent = (0 , hist2d . shape [ 0 ] , dp .bin_min , dp .bin_max ),
198+ extent = (0 , x_max , dp .bin_min , dp .bin_max ),
197199 cmap = "hot" ,
198200 )
199201
200202 for unit_index , unit_id in enumerate (dp .unit_ids ):
201- if dp .same_axis :
202- ax = self .ax
203- else :
204- ax = self .axes .flatten ()[unit_index ]
203+ ax = self .ax if dp .same_axis else self .axes .flatten ()[unit_index ]
205204 color = dp .unit_colors [unit_id ]
206- ax .plot (dp .templates_flat [unit_id ], color = color , lw = 1 )
205+ x = np .arange (len (dp .templates_flat [unit_id ])) / freq_khz
206+ ax .plot (x , dp .templates_flat [unit_id ], color = color , lw = 1 )
207207
208208 # final cosmetics
209209 for unit_index , unit_id in enumerate (dp .unit_ids ):
@@ -216,11 +216,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
216216 chan_inds = dp .channel_inds [unit_id ]
217217 for i , chan_ind in enumerate (chan_inds ):
218218 if i != 0 :
219- ax .axvline (i * dp .template_width , color = "w" , lw = 3 )
219+ ax .axvline (i * dp .template_width / freq_khz , color = "w" , lw = 3 )
220220 channel_id = dp .channel_ids [chan_ind ]
221- x = i * dp . template_width + dp .template_width // 2
221+ x = ( i + 0.5 ) * dp .template_width / freq_khz
222222 y = (dp .bin_max + dp .bin_min ) / 2.0
223223 ax .text (x , y , f"chan_id { channel_id } " , color = "w" , ha = "center" , va = "center" )
224224
225- ax .set_xticks ([] )
225+ ax .set_xlabel ( 'Time [ms]' )
226226 ax .set_ylabel (f"unit_id { unit_id } " )
0 commit comments