Skip to content

Commit e5b932c

Browse files
committed
plot density map with msecs on x axis
1 parent c30f587 commit e5b932c

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

src/spikeinterface/widgets/unit_summary.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
180180
**unitwaveformdensitymapwidget_kwargs,
181181
)
182182
col_counter += 1
183+
ax_waveform_density.set_xlabel(None)
183184
ax_waveform_density.set_ylabel(None)
184185

185186
if sorting_analyzer.has_extension("correlograms"):

src/spikeinterface/widgets/unit_waveforms_density_map.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(
121121
wfs = wfs_
122122

123123
# make histogram density
124-
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x times*num_channels
124+
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x (num_channels * timepoints)
125125
hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T]
126126
hist2d = np.stack(hists_per_timepoint)
127127

@@ -157,6 +157,7 @@ def __init__(
157157
bin_min=bin_min,
158158
bin_max=bin_max,
159159
all_hist2d=all_hist2d,
160+
sampling_frequency=sorting_analyzer.sampling_frequency,
160161
templates_flat=templates_flat,
161162
template_width=wfs.shape[1],
162163
)
@@ -173,37 +174,36 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
173174
backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids)
174175
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
175176

177+
freq_khz = dp.sampling_frequency / 1000 # samples / msec
176178
if dp.same_axis:
177-
ax = self.ax
178179
hist2d = dp.all_hist2d
179-
im = ax.imshow(
180+
x_max = len(hist2d) / freq_khz # in milliseconds
181+
self.ax.imshow(
180182
hist2d.T,
181183
interpolation="nearest",
182184
origin="lower",
183185
aspect="auto",
184-
extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
186+
extent=(0, x_max, dp.bin_min, dp.bin_max),
185187
cmap="hot",
186188
)
187189
else:
188-
for unit_index, unit_id in enumerate(dp.unit_ids):
190+
for ax, unit_id in zip(self.axes.flatten(), dp.unit_ids):
189191
hist2d = dp.all_hist2d[unit_id]
190-
ax = self.axes.flatten()[unit_index]
191-
im = ax.imshow(
192+
x_max = len(hist2d) / freq_khz # in milliseconds
193+
ax.imshow(
192194
hist2d.T,
193195
interpolation="nearest",
194196
origin="lower",
195197
aspect="auto",
196-
extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
198+
extent=(0, x_max, dp.bin_min, dp.bin_max),
197199
cmap="hot",
198200
)
199201

200202
for unit_index, unit_id in enumerate(dp.unit_ids):
201-
if dp.same_axis:
202-
ax = self.ax
203-
else:
204-
ax = self.axes.flatten()[unit_index]
203+
ax = self.ax if dp.same_axis else self.axes.flatten()[unit_index]
205204
color = dp.unit_colors[unit_id]
206-
ax.plot(dp.templates_flat[unit_id], color=color, lw=1)
205+
x = np.arange(len(dp.templates_flat[unit_id])) / freq_khz
206+
ax.plot(x, dp.templates_flat[unit_id], color=color, lw=1)
207207

208208
# final cosmetics
209209
for unit_index, unit_id in enumerate(dp.unit_ids):
@@ -216,11 +216,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
216216
chan_inds = dp.channel_inds[unit_id]
217217
for i, chan_ind in enumerate(chan_inds):
218218
if i != 0:
219-
ax.axvline(i * dp.template_width, color="w", lw=3)
219+
ax.axvline(i * dp.template_width / freq_khz, color="w", lw=3)
220220
channel_id = dp.channel_ids[chan_ind]
221-
x = i * dp.template_width + dp.template_width // 2
221+
x = (i + 0.5) * dp.template_width / freq_khz
222222
y = (dp.bin_max + dp.bin_min) / 2.0
223223
ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center")
224224

225-
ax.set_xticks([])
225+
ax.set_xlabel('Time [ms]')
226226
ax.set_ylabel(f"unit_id {unit_id}")

0 commit comments

Comments
 (0)