Skip to content

Commit 69e29a6

Browse files
committed
Add main_setting to use recorting times (and other time-related fixes)
1 parent 0443dcb commit 69e29a6

File tree

11 files changed

+279
-135
lines changed

11 files changed

+279
-135
lines changed

spikeinterface_gui/backend_qt.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class SignalNotifier(QT.QObject):
1919
channel_visibility_changed = QT.pyqtSignal()
2020
manual_curation_updated = QT.pyqtSignal()
2121
time_info_updated = QT.pyqtSignal()
22+
use_times_updated = QT.pyqtSignal()
2223
unit_color_changed = QT.pyqtSignal()
2324

2425
def __init__(self, parent=None, view=None):
@@ -40,6 +41,9 @@ def notify_manual_curation_updated(self):
4041
def notify_time_info_updated(self):
4142
self.time_info_updated.emit()
4243

44+
def notify_use_times_updated(self):
45+
self.use_times_updated.emit()
46+
4347
def notify_unit_color_changed(self):
4448
self.unit_color_changed.emit()
4549

@@ -63,6 +67,7 @@ def connect_view(self, view):
6367
view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed)
6468
view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated)
6569
view.notifier.time_info_updated.connect(self.on_time_info_updated)
70+
view.notifier.use_times_updated.connect(self.on_use_times_updated)
6671
view.notifier.unit_color_changed.connect(self.on_unit_color_changed)
6772

6873
def on_spike_selection_changed(self):
@@ -110,7 +115,16 @@ def on_time_info_updated(self):
110115
# do not refresh it self
111116
continue
112117
view.on_time_info_updated()
113-
118+
119+
def on_use_times_updated(self):
120+
if not self._active:
121+
return
122+
for view in self.controller.views:
123+
if view.qt_widget == self.sender().parent():
124+
# do not refresh it self
125+
continue
126+
view.on_use_times_updated()
127+
114128
def on_unit_color_changed(self):
115129
if not self._active:
116130
return

spikeinterface_gui/basescatterview.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
3636
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
3737

3838

39-
def get_unit_data(self, unit_id, seg_index=0):
40-
inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index)
41-
spike_times = self.controller.spikes["sample_index"][inds] / self.controller.sampling_frequency
39+
def get_unit_data(self, unit_id, segment_index=0):
40+
inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index)
41+
spike_indices = self.controller.spikes["sample_index"][inds]
42+
spike_times = self.controller.sample_index_to_time(spike_indices)
4243
spike_data = self.spike_data[inds]
4344
ptp = np.ptp(spike_data)
4445
hist_min, hist_max = [np.min(spike_data) - 0.2 * ptp, np.max(spike_data) + 0.2 * ptp]
@@ -53,8 +54,8 @@ def get_unit_data(self, unit_id, seg_index=0):
5354

5455
return spike_times, spike_data, hist_count, hist_bins, inds
5556

56-
def get_selected_spikes_data(self, seg_index=0, visible_inds=None):
57-
sl = self.controller.segment_slices[seg_index]
57+
def get_selected_spikes_data(self, segment_index=0, visible_inds=None):
58+
sl = self.controller.segment_slices[segment_index]
5859
spikes_in_seg = self.controller.spikes[sl]
5960
selected_indices = self.controller.get_indices_spike_selected()
6061
if visible_inds is not None:
@@ -85,7 +86,7 @@ def select_all_spikes_from_lasso(self, keep_already_selected=False):
8586
for segment_index, vertices in self._lasso_vertices.items():
8687
if vertices is None:
8788
continue
88-
spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index)
89+
spike_inds = self.controller.get_spike_indices(visible_unit_id, segment_index=segment_index)
8990
spike_times = self.controller.spikes["sample_index"][spike_inds] / fs
9091
spike_data = self.spike_data[spike_inds]
9192

@@ -119,7 +120,7 @@ def split(self):
119120

120121
if self.controller.num_segments > 1:
121122
# check that lasso vertices are defined for all segments
122-
if not all(self._lasso_vertices[seg_index] is not None for seg_index in range(self.controller.num_segments)):
123+
if not all(self._lasso_vertices[segment_index] is not None for segment_index in range(self.controller.num_segments)):
123124
# Use the new continue_from_user pattern
124125
self.continue_from_user(
125126
"Not all segments have lasso selection. "
@@ -163,6 +164,12 @@ def on_unit_visibility_changed(self):
163164
self._current_selected = self.controller.get_indices_spike_selected().size
164165
self.refresh()
165166

167+
def on_time_info_updated(self):
168+
return self.refresh()
169+
170+
def on_use_times_updated(self):
171+
return self.refresh()
172+
166173
## QT zone ##
167174
def _qt_make_layout(self):
168175
from .myqt import QT
@@ -174,8 +181,8 @@ def _qt_make_layout(self):
174181
tb = self.qt_widget.view_toolbar
175182
self.combo_seg = QT.QComboBox()
176183
tb.addWidget(self.combo_seg)
177-
self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ])
178-
self.combo_seg.currentIndexChanged.connect(self.refresh)
184+
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
185+
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
179186
add_stretch_to_qtoolbar(tb)
180187
self.lasso_but = QT.QPushButton("select", checkable = True)
181188
tb.addWidget(self.lasso_but)
@@ -235,6 +242,12 @@ def _qt_initialize_plot(self):
235242
def _qt_on_spike_selection_changed(self):
236243
self.refresh()
237244

245+
def _qt_change_segment(self):
246+
segment_index = self.combo_seg.currentIndex()
247+
self.controller.set_time(segment_index=segment_index)
248+
self.refresh()
249+
self.notify_time_info_updated()
250+
238251
def _qt_refresh(self):
239252
from .myqt import QT
240253
import pyqtgraph as pg
@@ -246,13 +259,18 @@ def _qt_refresh(self):
246259
if self.spike_data is None:
247260
return
248261

262+
segment_index = self.controller.get_time()[1]
263+
# Update combo_seg if it doesn't match the current segment index
264+
if self.combo_seg.currentIndex() != segment_index:
265+
self.combo_seg.setCurrentIndex(segment_index)
266+
249267
max_count = 1
250268
all_inds = []
251269
for unit_id in self.controller.get_visible_unit_ids():
252270

253271
spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
254272
unit_id,
255-
seg_index=self.combo_seg.currentIndex()
273+
segment_index=segment_index
256274
)
257275

258276
# make a copy of the color
@@ -276,7 +294,7 @@ def _qt_refresh(self):
276294
y_range_plot_1 = self.plot.getViewBox().viewRange()
277295
self.viewBox2.setYRange(y_range_plot_1[1][0], y_range_plot_1[1][1], padding = 0.0)
278296

279-
spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
297+
spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
280298

281299
self.scatter_select.setData(spike_times, spike_data)
282300

@@ -296,8 +314,8 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
296314
self.lasso.setData([], [])
297315
vertices = np.array(points)
298316

299-
seg_index = self.combo_seg.currentIndex()
300-
sl = self.controller.segment_slices[seg_index]
317+
segment_index = self.combo_seg.currentIndex()
318+
sl = self.controller.segment_slices[segment_index]
301319
spikes_in_seg = self.controller.spikes[sl]
302320

303321
# Create mask for visible units
@@ -315,16 +333,16 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
315333
self.notify_spike_selection_changed()
316334
return
317335

318-
if self._lasso_vertices[seg_index] is None:
319-
self._lasso_vertices[seg_index] = []
336+
if self._lasso_vertices[segment_index] is None:
337+
self._lasso_vertices[segment_index] = []
320338

321339
if shift_held:
322340
# If shift is held, append the vertices to the current lasso vertices
323-
self._lasso_vertices[seg_index].append(vertices)
341+
self._lasso_vertices[segment_index].append(vertices)
324342
keep_already_selected = True
325343
else:
326344
# If shift is not held, clear the existing lasso vertices for this segment
327-
self._lasso_vertices[seg_index] = [vertices]
345+
self._lasso_vertices[segment_index] = [vertices]
328346
keep_already_selected = False
329347

330348
self.select_all_spikes_from_lasso(keep_already_selected=keep_already_selected)
@@ -445,11 +463,13 @@ def _panel_refresh(self):
445463
ys = []
446464
colors = []
447465

466+
segment_index = self.controller.get_time()[1]
467+
448468
visible_unit_ids = self.controller.get_visible_unit_ids()
449469
for unit_id in visible_unit_ids:
450470
spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
451471
unit_id,
452-
seg_index=self.segment_index
472+
segment_index=segment_index
453473
)
454474
color = self.get_unit_color(unit_id)
455475
xs.extend(spike_times)
@@ -504,9 +524,12 @@ def _panel_on_select_button(self, event):
504524
def _panel_change_segment(self, event):
505525
self._current_selected = 0
506526
self.segment_index = int(self.segment_selector.value.split()[-1])
507-
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
508-
self.scatter_fig.x_range.end = time_max
527+
self.controller.set_time(segment_index=self.segment_index)
528+
t_start, t_end = self.controller.get_t_start_t_end()
529+
self.scatter_fig.x_range.start = t_start
530+
self.scatter_fig.x_range.end = t_end
509531
self.refresh()
532+
self.notify_time_info_updated()
510533

511534
def _on_panel_selection_geometry(self, event):
512535
"""
@@ -524,16 +547,16 @@ def _on_panel_selection_geometry(self, event):
524547
return
525548

526549
# Append the current polygon to the lasso vertices if shift is held
527-
seg_index = self.segment_index
528-
if self._lasso_vertices[seg_index] is None:
529-
self._lasso_vertices[seg_index] = []
550+
segment_index = self.segment_index
551+
if self._lasso_vertices[segment_index] is None:
552+
self._lasso_vertices[segment_index] = []
530553
if len(selected) > self._current_selected:
531554
self._current_selected = len(selected)
532555
# Store the current polygon for the current segment
533-
self._lasso_vertices[seg_index].append(polygon)
556+
self._lasso_vertices[segment_index].append(polygon)
534557
keep_already_selected = True
535558
else:
536-
self._lasso_vertices[seg_index] = [polygon]
559+
self._lasso_vertices[segment_index] = [polygon]
537560
keep_already_selected = False
538561

539562
self.select_all_spikes_from_lasso(keep_already_selected)

spikeinterface_gui/controller.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_default_main_settings = dict(
2626
max_visible_units=10,
2727
color_mode='color_by_unit',
28+
use_times=False
2829
)
2930

3031
from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties
@@ -264,7 +265,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
264265

265266
# self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
266267
seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1))
267-
self.segment_slices = {seg_index: slice(seg_limits[seg_index], seg_limits[seg_index + 1]) for seg_index in range(num_seg)}
268+
self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)}
268269

269270
spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
270271
self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2]
@@ -275,7 +276,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
275276
spike_per_seg = [s.size for s in spike_vector2]
276277
# dict[unit_id] -> all indices for this unit across segments
277278
self._spike_index_by_units = {}
278-
# dict[seg_index][unit_id] -> all indices for this unit for one segment
279+
# dict[segment_index][unit_id] -> all indices for this unit for one segment
279280
self._spike_index_by_segment_and_units = spike_indices_abs
280281
for unit_id in unit_ids:
281282
inds = []
@@ -302,10 +303,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
302303
self.displayed_unit_properties = displayed_unit_properties
303304

304305
# set default time info
305-
self.time_info = dict(
306-
time_by_seg=np.array([0] * self.num_segments, dtype="float64"),
307-
segment_index=0
308-
)
306+
self.update_time_info()
309307

310308
self.curation = curation
311309
# TODO: Reload the dictionary if it already exists
@@ -401,10 +399,10 @@ def get_time(self):
401399
"""
402400
Returns selected time and segment index
403401
"""
404-
seg_index = self.time_info['segment_index']
402+
segment_index = self.time_info['segment_index']
405403
time_by_seg = self.time_info['time_by_seg']
406-
time = time_by_seg[seg_index]
407-
return time, seg_index
404+
time = time_by_seg[segment_index]
405+
return time, segment_index
408406

409407
def set_time(self, time=None, segment_index=None):
410408
"""
@@ -418,7 +416,49 @@ def set_time(self, time=None, segment_index=None):
418416
segment_index = self.time_info['segment_index']
419417
if time is not None:
420418
self.time_info['time_by_seg'][segment_index] = time
421-
419+
420+
def update_time_info(self):
421+
# set default time info
422+
if self.main_settings["use_times"] and self.analyzer.has_recording():
423+
self.time_info = dict(
424+
time_by_seg=np.array(
425+
[
426+
self.analyzer.recording.get_start_time(segment_index) for segment_index in range(self.num_segments)
427+
],
428+
dtype="float64"),
429+
segment_index=0
430+
)
431+
else:
432+
self.time_info = dict(
433+
time_by_seg=np.array([0] * self.num_segments, dtype="float64"),
434+
segment_index=0
435+
)
436+
437+
def get_t_start_t_stop(self):
438+
segment_index = self.time_info["segment_index"]
439+
if self.main_settings["use_times"] and self.analyzer.has_recording():
440+
t_start = self.analyzer.recording.get_start_time(segment_index=segment_index)
441+
t_stop = self.analyzer.recording.get_end_time(segment_index=segment_index)
442+
return t_start, t_stop
443+
else:
444+
return 0, self.get_num_samples(segment_index) / self.sampling_frequency
445+
446+
def sample_index_to_time(self, sample_index):
447+
segment_index = self.time_info["segment_index"]
448+
if self.main_settings["use_times"] and self.analyzer.has_recording():
449+
time = self.analyzer.recording.sample_index_to_time(sample_index, segment_index=segment_index)
450+
return time
451+
else:
452+
return sample_index / self.sampling_frequency
453+
454+
def time_to_sample_index(self, time):
455+
segment_index = self.time_info["segment_index"]
456+
if self.main_settings["use_times"] and self.analyzer.has_recording():
457+
time = self.analyzer.recording.time_to_sample_index(time, segment_index=segment_index)
458+
return time
459+
else:
460+
return int(time * self.sampling_frequency)
461+
422462
def get_information_txt(self):
423463
nseg = self.analyzer.get_num_segments()
424464
nchan = self.analyzer.get_num_channels()
@@ -552,13 +592,13 @@ def set_indices_spike_selected(self, inds):
552592
sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]]
553593
self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index)
554594

555-
def get_spike_indices(self, unit_id, seg_index=None):
556-
if seg_index is None:
595+
def get_spike_indices(self, unit_id, segment_index=None):
596+
if segment_index is None:
557597
# dict[unit_id] -> all indices for this unit across segments
558598
return self._spike_index_by_units[unit_id]
559599
else:
560-
# dict[seg_index][unit_id] -> all indices for this unit for one segment
561-
return self._spike_index_by_segment_and_units[seg_index][unit_id]
600+
# dict[segment_index][unit_id] -> all indices for this unit for one segment
601+
return self._spike_index_by_segment_and_units[segment_index][unit_id]
562602

563603
def get_num_samples(self, segment_index):
564604
return self.analyzer.get_num_samples(segment_index=segment_index)
@@ -838,7 +878,7 @@ def make_manual_split_if_possible(self, unit_id):
838878
indices = self.get_indices_spike_selected()
839879
if len(indices) == 0:
840880
return False
841-
spike_inds = self.get_spike_indices(unit_id, seg_index=None)
881+
spike_inds = self.get_spike_indices(unit_id, segment_index=None)
842882
if not np.all(np.isin(indices, spike_inds)):
843883
return False
844884

spikeinterface_gui/curationview.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def unsplit(self):
5353
def select_and_notify_split(self, split_unit_id):
5454
self.controller.set_visible_unit_ids([split_unit_id])
5555
self.notify_unit_visibility_changed()
56-
spike_inds = self.controller.get_spike_indices(split_unit_id, seg_index=None)
56+
spike_inds = self.controller.get_spike_indices(split_unit_id, segment_index=None)
5757
active_split = [s for s in self.controller.curation_data['splits'] if s['unit_id'] == split_unit_id][0]
5858
split_indices = active_split['indices'][0]
5959
self.controller.set_indices_spike_selected(spike_inds[split_indices])

0 commit comments

Comments
 (0)