From 69e29a65a3136aa0affd032c8702a20b8473419e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 3 Oct 2025 17:52:28 +0200 Subject: [PATCH 1/8] Add main_setting to use recorting times (and other time-related fixes) --- spikeinterface_gui/backend_qt.py | 16 ++- spikeinterface_gui/basescatterview.py | 73 +++++++---- spikeinterface_gui/controller.py | 70 ++++++++--- spikeinterface_gui/curationview.py | 2 +- spikeinterface_gui/mainsettingsview.py | 12 +- spikeinterface_gui/rateview.py | 55 ++++++--- .../tests/test_mainwindow_qt.py | 7 ++ spikeinterface_gui/tracemapview.py | 44 ++++--- spikeinterface_gui/traceview.py | 114 ++++++++++-------- spikeinterface_gui/utils_qt.py | 6 +- spikeinterface_gui/view_base.py | 15 +++ 11 files changed, 279 insertions(+), 135 deletions(-) diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index b9f8986..648bf0a 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -19,6 +19,7 @@ class SignalNotifier(QT.QObject): channel_visibility_changed = QT.pyqtSignal() manual_curation_updated = QT.pyqtSignal() time_info_updated = QT.pyqtSignal() + use_times_updated = QT.pyqtSignal() unit_color_changed = QT.pyqtSignal() def __init__(self, parent=None, view=None): @@ -40,6 +41,9 @@ def notify_manual_curation_updated(self): def notify_time_info_updated(self): self.time_info_updated.emit() + def notify_use_times_updated(self): + self.use_times_updated.emit() + def notify_unit_color_changed(self): self.unit_color_changed.emit() @@ -63,6 +67,7 @@ def connect_view(self, view): view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed) view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated) view.notifier.time_info_updated.connect(self.on_time_info_updated) + view.notifier.use_times_updated.connect(self.on_use_times_updated) view.notifier.unit_color_changed.connect(self.on_unit_color_changed) def on_spike_selection_changed(self): @@ -110,7 +115,16 @@ def on_time_info_updated(self): # do not refresh it self continue view.on_time_info_updated() - + + def on_use_times_updated(self): + if not self._active: + return + for view in self.controller.views: + if view.qt_widget == self.sender().parent(): + # do not refresh it self + continue + view.on_use_times_updated() + def on_unit_color_changed(self): if not self._active: return diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 22a4939..96398c0 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -36,9 +36,10 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) - def get_unit_data(self, unit_id, seg_index=0): - inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index) - spike_times = self.controller.spikes["sample_index"][inds] / self.controller.sampling_frequency + def get_unit_data(self, unit_id, segment_index=0): + inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index) + spike_indices = self.controller.spikes["sample_index"][inds] + spike_times = self.controller.sample_index_to_time(spike_indices) spike_data = self.spike_data[inds] ptp = np.ptp(spike_data) hist_min, hist_max = [np.min(spike_data) - 0.2 * ptp, np.max(spike_data) + 0.2 * ptp] @@ -53,8 +54,8 @@ def get_unit_data(self, unit_id, seg_index=0): return spike_times, spike_data, hist_count, hist_bins, inds - def get_selected_spikes_data(self, seg_index=0, visible_inds=None): - sl = self.controller.segment_slices[seg_index] + def get_selected_spikes_data(self, segment_index=0, visible_inds=None): + sl = self.controller.segment_slices[segment_index] spikes_in_seg = self.controller.spikes[sl] selected_indices = self.controller.get_indices_spike_selected() if visible_inds is not None: @@ -85,7 +86,7 @@ def select_all_spikes_from_lasso(self, keep_already_selected=False): for segment_index, vertices in self._lasso_vertices.items(): if vertices is None: continue - spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index) + spike_inds = self.controller.get_spike_indices(visible_unit_id, segment_index=segment_index) spike_times = self.controller.spikes["sample_index"][spike_inds] / fs spike_data = self.spike_data[spike_inds] @@ -119,7 +120,7 @@ def split(self): if self.controller.num_segments > 1: # check that lasso vertices are defined for all segments - if not all(self._lasso_vertices[seg_index] is not None for seg_index in range(self.controller.num_segments)): + if not all(self._lasso_vertices[segment_index] is not None for segment_index in range(self.controller.num_segments)): # Use the new continue_from_user pattern self.continue_from_user( "Not all segments have lasso selection. " @@ -163,6 +164,12 @@ def on_unit_visibility_changed(self): self._current_selected = self.controller.get_indices_spike_selected().size self.refresh() + def on_time_info_updated(self): + return self.refresh() + + def on_use_times_updated(self): + return self.refresh() + ## QT zone ## def _qt_make_layout(self): from .myqt import QT @@ -174,8 +181,8 @@ def _qt_make_layout(self): tb = self.qt_widget.view_toolbar self.combo_seg = QT.QComboBox() tb.addWidget(self.combo_seg) - self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ]) - self.combo_seg.currentIndexChanged.connect(self.refresh) + self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ]) + self.combo_seg.currentIndexChanged.connect(self._qt_change_segment) add_stretch_to_qtoolbar(tb) self.lasso_but = QT.QPushButton("select", checkable = True) tb.addWidget(self.lasso_but) @@ -235,6 +242,12 @@ def _qt_initialize_plot(self): def _qt_on_spike_selection_changed(self): self.refresh() + def _qt_change_segment(self): + segment_index = self.combo_seg.currentIndex() + self.controller.set_time(segment_index=segment_index) + self.refresh() + self.notify_time_info_updated() + def _qt_refresh(self): from .myqt import QT import pyqtgraph as pg @@ -246,13 +259,18 @@ def _qt_refresh(self): if self.spike_data is None: return + segment_index = self.controller.get_time()[1] + # Update combo_seg if it doesn't match the current segment index + if self.combo_seg.currentIndex() != segment_index: + self.combo_seg.setCurrentIndex(segment_index) + max_count = 1 all_inds = [] for unit_id in self.controller.get_visible_unit_ids(): spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data( unit_id, - seg_index=self.combo_seg.currentIndex() + segment_index=segment_index ) # make a copy of the color @@ -276,7 +294,7 @@ def _qt_refresh(self): y_range_plot_1 = self.plot.getViewBox().viewRange() self.viewBox2.setYRange(y_range_plot_1[1][0], y_range_plot_1[1][1], padding = 0.0) - spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex(), visible_inds=all_inds) + spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds) self.scatter_select.setData(spike_times, spike_data) @@ -296,8 +314,8 @@ def _qt_on_lasso_finished(self, points, shift_held=False): self.lasso.setData([], []) vertices = np.array(points) - seg_index = self.combo_seg.currentIndex() - sl = self.controller.segment_slices[seg_index] + segment_index = self.combo_seg.currentIndex() + sl = self.controller.segment_slices[segment_index] spikes_in_seg = self.controller.spikes[sl] # Create mask for visible units @@ -315,16 +333,16 @@ def _qt_on_lasso_finished(self, points, shift_held=False): self.notify_spike_selection_changed() return - if self._lasso_vertices[seg_index] is None: - self._lasso_vertices[seg_index] = [] + if self._lasso_vertices[segment_index] is None: + self._lasso_vertices[segment_index] = [] if shift_held: # If shift is held, append the vertices to the current lasso vertices - self._lasso_vertices[seg_index].append(vertices) + self._lasso_vertices[segment_index].append(vertices) keep_already_selected = True else: # If shift is not held, clear the existing lasso vertices for this segment - self._lasso_vertices[seg_index] = [vertices] + self._lasso_vertices[segment_index] = [vertices] keep_already_selected = False self.select_all_spikes_from_lasso(keep_already_selected=keep_already_selected) @@ -445,11 +463,13 @@ def _panel_refresh(self): ys = [] colors = [] + segment_index = self.controller.get_time()[1] + visible_unit_ids = self.controller.get_visible_unit_ids() for unit_id in visible_unit_ids: spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data( unit_id, - seg_index=self.segment_index + segment_index=segment_index ) color = self.get_unit_color(unit_id) xs.extend(spike_times) @@ -504,9 +524,12 @@ def _panel_on_select_button(self, event): def _panel_change_segment(self, event): self._current_selected = 0 self.segment_index = int(self.segment_selector.value.split()[-1]) - time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency - self.scatter_fig.x_range.end = time_max + self.controller.set_time(segment_index=self.segment_index) + t_start, t_end = self.controller.get_t_start_t_end() + self.scatter_fig.x_range.start = t_start + self.scatter_fig.x_range.end = t_end self.refresh() + self.notify_time_info_updated() def _on_panel_selection_geometry(self, event): """ @@ -524,16 +547,16 @@ def _on_panel_selection_geometry(self, event): return # Append the current polygon to the lasso vertices if shift is held - seg_index = self.segment_index - if self._lasso_vertices[seg_index] is None: - self._lasso_vertices[seg_index] = [] + segment_index = self.segment_index + if self._lasso_vertices[segment_index] is None: + self._lasso_vertices[segment_index] = [] if len(selected) > self._current_selected: self._current_selected = len(selected) # Store the current polygon for the current segment - self._lasso_vertices[seg_index].append(polygon) + self._lasso_vertices[segment_index].append(polygon) keep_already_selected = True else: - self._lasso_vertices[seg_index] = [polygon] + self._lasso_vertices[segment_index] = [polygon] keep_already_selected = False self.select_all_spikes_from_lasso(keep_already_selected) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index c9c1dbe..cafeadf 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -25,6 +25,7 @@ _default_main_settings = dict( max_visible_units=10, color_mode='color_by_unit', + use_times=False ) from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties @@ -264,7 +265,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save # self.num_spikes = self.analyzer.sorting.count_num_spikes_per_unit(outputs="dict") seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) - self.segment_slices = {seg_index: slice(seg_limits[seg_index], seg_limits[seg_index + 1]) for seg_index in range(num_seg)} + self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)} spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False) self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2] @@ -275,7 +276,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save spike_per_seg = [s.size for s in spike_vector2] # dict[unit_id] -> all indices for this unit across segments self._spike_index_by_units = {} - # dict[seg_index][unit_id] -> all indices for this unit for one segment + # dict[segment_index][unit_id] -> all indices for this unit for one segment self._spike_index_by_segment_and_units = spike_indices_abs for unit_id in unit_ids: inds = [] @@ -302,10 +303,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.displayed_unit_properties = displayed_unit_properties # set default time info - self.time_info = dict( - time_by_seg=np.array([0] * self.num_segments, dtype="float64"), - segment_index=0 - ) + self.update_time_info() self.curation = curation # TODO: Reload the dictionary if it already exists @@ -401,10 +399,10 @@ def get_time(self): """ Returns selected time and segment index """ - seg_index = self.time_info['segment_index'] + segment_index = self.time_info['segment_index'] time_by_seg = self.time_info['time_by_seg'] - time = time_by_seg[seg_index] - return time, seg_index + time = time_by_seg[segment_index] + return time, segment_index def set_time(self, time=None, segment_index=None): """ @@ -418,7 +416,49 @@ def set_time(self, time=None, segment_index=None): segment_index = self.time_info['segment_index'] if time is not None: self.time_info['time_by_seg'][segment_index] = time - + + def update_time_info(self): + # set default time info + if self.main_settings["use_times"] and self.analyzer.has_recording(): + self.time_info = dict( + time_by_seg=np.array( + [ + self.analyzer.recording.get_start_time(segment_index) for segment_index in range(self.num_segments) + ], + dtype="float64"), + segment_index=0 + ) + else: + self.time_info = dict( + time_by_seg=np.array([0] * self.num_segments, dtype="float64"), + segment_index=0 + ) + + def get_t_start_t_stop(self): + segment_index = self.time_info["segment_index"] + if self.main_settings["use_times"] and self.analyzer.has_recording(): + t_start = self.analyzer.recording.get_start_time(segment_index=segment_index) + t_stop = self.analyzer.recording.get_end_time(segment_index=segment_index) + return t_start, t_stop + else: + return 0, self.get_num_samples(segment_index) / self.sampling_frequency + + def sample_index_to_time(self, sample_index): + segment_index = self.time_info["segment_index"] + if self.main_settings["use_times"] and self.analyzer.has_recording(): + time = self.analyzer.recording.sample_index_to_time(sample_index, segment_index=segment_index) + return time + else: + return sample_index / self.sampling_frequency + + def time_to_sample_index(self, time): + segment_index = self.time_info["segment_index"] + if self.main_settings["use_times"] and self.analyzer.has_recording(): + time = self.analyzer.recording.time_to_sample_index(time, segment_index=segment_index) + return time + else: + return int(time * self.sampling_frequency) + def get_information_txt(self): nseg = self.analyzer.get_num_segments() nchan = self.analyzer.get_num_channels() @@ -552,13 +592,13 @@ def set_indices_spike_selected(self, inds): sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]] self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index) - def get_spike_indices(self, unit_id, seg_index=None): - if seg_index is None: + def get_spike_indices(self, unit_id, segment_index=None): + if segment_index is None: # dict[unit_id] -> all indices for this unit across segments return self._spike_index_by_units[unit_id] else: - # dict[seg_index][unit_id] -> all indices for this unit for one segment - return self._spike_index_by_segment_and_units[seg_index][unit_id] + # dict[segment_index][unit_id] -> all indices for this unit for one segment + return self._spike_index_by_segment_and_units[segment_index][unit_id] def get_num_samples(self, segment_index): return self.analyzer.get_num_samples(segment_index=segment_index) @@ -838,7 +878,7 @@ def make_manual_split_if_possible(self, unit_id): indices = self.get_indices_spike_selected() if len(indices) == 0: return False - spike_inds = self.get_spike_indices(unit_id, seg_index=None) + spike_inds = self.get_spike_indices(unit_id, segment_index=None) if not np.all(np.isin(indices, spike_inds)): return False diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 5164afa..148eae6 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -53,7 +53,7 @@ def unsplit(self): def select_and_notify_split(self, split_unit_id): self.controller.set_visible_unit_ids([split_unit_id]) self.notify_unit_visibility_changed() - spike_inds = self.controller.get_spike_indices(split_unit_id, seg_index=None) + spike_inds = self.controller.get_spike_indices(split_unit_id, segment_index=None) active_split = [s for s in self.controller.curation_data['splits'] if s['unit_id'] == split_unit_id][0] split_indices = active_split['indices'][0] self.controller.set_indices_spike_selected(spike_inds[split_indices]) diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index a66c3b5..e70568e 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -6,6 +6,7 @@ {'name': 'max_visible_units', 'type': 'int', 'value' : 10 }, {'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit', 'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']}, + {'name': 'use_times', 'type': 'bool', 'value': False} ] @@ -35,6 +36,11 @@ def on_change_color_mode(self): self.controller.refresh_colors() self.notify_unit_color_changed() + + def on_use_times(self): + self.controller.main_settings['use_times'] = self.main_settings['use_times'] + self.controller.update_time_info() + self.notify_use_times_updated() # for view in self.controller.views: # view.refresh() @@ -60,7 +66,7 @@ def _qt_make_layout(self): self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed) self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode) - + self.main_settings.param('use_times').sigValueChanged.connect(self.on_use_times) def _qt_refresh(self): pass @@ -77,6 +83,7 @@ def _panel_make_layout(self): name=f"Main settings") self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units') self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode') + self.main_settings._parameterized.param.watch(self._panel_on_use_times, 'use_times') self.layout = pn.Column(self.main_settings_layout, sizing_mode="stretch_both") def _panel_on_max_visible_units_changed(self, event): @@ -85,6 +92,9 @@ def _panel_on_max_visible_units_changed(self, event): def _panel_on_change_color_mode(self, event): self.on_change_color_mode() + def _panel_on_use_times(self, event): + self.on_use_times() + def _panel_refresh(self): pass diff --git a/spikeinterface_gui/rateview.py b/spikeinterface_gui/rateview.py index a6c016d..c1a204e 100644 --- a/spikeinterface_gui/rateview.py +++ b/spikeinterface_gui/rateview.py @@ -15,6 +15,13 @@ def __init__(self, controller=None, parent=None, backend="qt"): def _on_settings_changed(self): self.refresh() + def on_time_info_updated(self): + self.refresh() + + def on_use_times_updated(self): + print(f"Refreshing SpikeRateView") + self.refresh() + ## Qt ## def _qt_make_layout(self): @@ -26,8 +33,8 @@ def _qt_make_layout(self): tb = self.qt_widget.view_toolbar self.combo_seg = QT.QComboBox() tb.addWidget(self.combo_seg) - self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ]) - self.combo_seg.currentIndexChanged.connect(self.refresh) + self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ]) + self.combo_seg.currentIndexChanged.connect(self._qt_change_segment) h = QT.QHBoxLayout() self.layout.addLayout(h) @@ -37,31 +44,41 @@ def _qt_make_layout(self): self.graphicsview.setCentralItem(self.plot) self.layout.addWidget(self.graphicsview) + def _qt_change_segment(self): + segment_index = self.combo_seg.currentIndex() + self.controller.set_time(segment_index=segment_index) + self.refresh() + self.notify_time_info_updated() + def _qt_refresh(self): import pyqtgraph as pg self.plot.clear() - seg_index = self.combo_seg.currentIndex() - + segment_index = self.controller.get_time()[1] + # Update combo_seg if it doesn't match the current segment index + if self.combo_seg.currentIndex() != segment_index: + self.combo_seg.setCurrentIndex(segment_index) + visible_unit_ids = self.controller.get_visible_unit_ids() sampling_frequency = self.controller.sampling_frequency total_frames = self.controller.final_spike_samples bins_s = self.settings['bin_s'] - num_bins = total_frames[seg_index] // int(sampling_frequency) // bins_s - + t_start, _ = self.controller.get_t_start_t_stop() + num_bins = total_frames[segment_index] // int(sampling_frequency) // bins_s + for r, unit_id in enumerate(visible_unit_ids): - spike_inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index) + spike_inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index) spikes = self.controller.spikes[spike_inds]['sample_index'] count, bins = np.histogram(spikes, bins=num_bins) color = self.get_unit_color(unit_id) curve = pg.PlotCurveItem( - (bins[1:]+bins[:-1])/(2*sampling_frequency), + (bins[1:]+bins[:-1])/(2*sampling_frequency) + t_start, count/bins_s, pen=pg.mkPen(color, width=2) ) @@ -107,12 +124,10 @@ def _panel_make_layout(self): self.is_warning_active = False def _panel_refresh(self): - import panel as pn - import bokeh.plotting as bpl - from bokeh.layouts import gridplot - from .utils_panel import _bg_color - - seg_index = self.segment_index + segment_index = self.controller.get_time()[1] + if segment_index != self.segment_index: + self.segment_index = segment_index + self.segment_selector.value = f"Segment {self.segment_index}" visible_unit_ids = self.controller.get_visible_unit_ids() @@ -120,14 +135,15 @@ def _panel_refresh(self): total_frames = self.controller.final_spike_samples bins_s = self.settings['bin_s'] - num_bins = total_frames[seg_index] // int(sampling_frequency) // bins_s + num_bins = total_frames[segment_index] // int(sampling_frequency) // bins_s + t_start, _ = self.controller.get_t_start_t_stop() # clear fig self.rate_fig.renderers = [] - + for unit_id in visible_unit_ids: - spike_inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index) + spike_inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index) spikes = self.controller.spikes[spike_inds]['sample_index'] count, bins = np.histogram(spikes, bins=num_bins) @@ -136,7 +152,7 @@ def _panel_refresh(self): color = self.get_unit_color(unit_id) line = self.rate_fig.line( - x=(bins[1:]+bins[:-1])/(2*sampling_frequency), + x=(bins[1:]+bins[:-1])/(2*sampling_frequency) + t_start, y=count/bins_s, color=color, line_width=2, @@ -146,8 +162,9 @@ def _panel_refresh(self): def _panel_change_segment(self, event): self.segment_index = int(self.segment_selector.value.split()[-1]) + self.controller.set_time(segment_index=self.segment_index) self.refresh() - + self.notify_time_info_updated() SpikeRateView._gui_help_txt = """ diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index d29f8ca..e547154 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -68,6 +68,13 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext yip=np.array([f"yip{i}" for i in range(n)]), ) + for segment_index in range(analyzer.get_num_segments()): + shift = (segment_index + 1) * 100 + analyzer.recording.set_times( + analyzer.recording.get_times(segment_index) + shift, + segment_index=segment_index + ) + win = run_mainwindow( analyzer, mode="desktop", diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index 31c04e3..efd0909 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -64,11 +64,7 @@ def make_color_lut(self): def get_data_in_chunk(self, t1, t2, segment_index): - t_start = 0.0 - sr = self.controller.sampling_frequency - - ind1 = max(0, int((t1 - t_start) * sr)) - ind2 = min(self.controller.get_num_samples(segment_index), int((t2 - t_start) * sr)) + ind1, ind2 = self.get_chunk_indices(t1, t2, segment_index) traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2) @@ -83,7 +79,8 @@ def get_data_in_chunk(self, t1, t2, segment_index): if data_curves.dtype != "float32": data_curves = data_curves.astype("float32") - times_chunk = np.arange(traces_chunk.shape[0], dtype='float64')/self.controller.sampling_frequency+max(t1, 0) + t_start, _ = self.controller.get_t_start_t_stop() + times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency + max(t1, t_start) scatter_x = [] scatter_y = [] @@ -192,13 +189,14 @@ def _qt_seek(self, t): sr = self.controller.sampling_frequency self.scroll_time.valueChanged.disconnect(self._qt_on_scroll_time) - self.scroll_time.setValue(int(sr*t)) + value = self.controller.time_to_sample_index(t) + self.scroll_time.setValue(value) self.scroll_time.setPageStep(int(sr*xsize)) self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) - seg_index = self.controller.get_time()[1] + segment_index = self.controller.get_time()[1] times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ - self.get_data_in_chunk(t1, t2, seg_index) + self.get_data_in_chunk(t1, t2, segment_index) if self.color_limit is None: self.color_limit = np.max(np.abs(data_curves)) @@ -217,16 +215,22 @@ def _qt_seek(self, t): def _qt_on_time_info_updated(self): # Update segment and time slider range - time, seg_index = self.controller.get_time() - + time, segment_index = self.controller.get_time() + # Block auto refresh to avoid recursive calls self._block_auto_refresh_and_notify = True - self._qt_change_segment(seg_index) - self.timeseeker.seek(time) + self._qt_change_segment(segment_index) + self.timeseeker.seek(time) + self._block_auto_refresh_and_notify = False - # we need a refresh in panel because changing tab triggers a refresh + # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() + def _qt_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.timeseeker.set_start_stop(t_start, t_stop) + ## Panel ## def _panel_make_layout(self): import panel as pn @@ -308,7 +312,7 @@ def _panel_make_layout(self): ) def _panel_refresh(self): - t, seg_index = self.controller.get_time() + t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 @@ -318,7 +322,7 @@ def _panel_refresh(self): auto_scale = False times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ - self.get_data_in_chunk(t1, t2, seg_index) + self.get_data_in_chunk(t1, t2, segment_index) if self.color_limit is None: self.color_limit = np.max(np.abs(data_curves)) @@ -349,8 +353,8 @@ def _panel_refresh(self): # TODO: if from a different unit, change unit visibility def _panel_on_tap(self, event): - seg_index = self.controller.get_time()[1] - ind_spike_nearest = self.find_nearest_spike(self.controller, event.x, seg_index) + segment_index = self.controller.get_time()[1] + ind_spike_nearest = self.find_nearest_spike(self.controller, event.x, segment_index) if ind_spike_nearest is not None: self.controller.set_indices_spike_selected([ind_spike_nearest]) self._panel_seek_with_selected_spike() @@ -376,10 +380,10 @@ def _panel_auto_scale(self, event): def _panel_on_time_info_updated(self): # Update segment and time slider range - time, seg_index = self.controller.get_time() + time, segment_index = self.controller.get_time() self._block_auto_refresh_and_notify = True - self._panel_change_segment(seg_index) + self._panel_change_segment(segment_index) # Update time slider value self.time_slider.value = time diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index c2c00cb..d5a35c4 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -24,7 +24,7 @@ def _qt_create_toolbar(self): #Segment selection self.combo_seg = QT.QComboBox() tb.addWidget(self.combo_seg) - self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ]) + self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ]) self.combo_seg.currentIndexChanged.connect(self._qt_on_combo_seg_changed) add_stretch_to_qtoolbar(tb) @@ -84,20 +84,19 @@ def _qt_initialize_plot(self): self.offsets = None def _qt_update_scroll_limits(self): - seg_index = self.controller.get_time()[1] - length = self.controller.get_num_samples(seg_index) - t_start = 0. - t_stop = length/self.controller.sampling_frequency + segment_index = self.controller.get_time()[1] + length = self.controller.get_num_samples(segment_index) + t_start, t_stop = self.controller.get_t_start_t_stop() self.timeseeker.set_start_stop(t_start, t_stop, seek=False) self.scroll_time.setMinimum(0) - self.scroll_time.setMaximum(length) + self.scroll_time.setMaximum(length - 1) - def _qt_change_segment(self, seg_index): - #TODO: dirty because now seg_pos IS seg_index - self.controller.set_time(segment_index=seg_index) + def _qt_change_segment(self, segment_index): + #TODO: dirty because now seg_pos IS segment_index + self.controller.set_time(segment_index=segment_index) - if seg_index != self.combo_seg.currentIndex(): - self.combo_seg.setCurrentIndex(seg_index) + if segment_index != self.combo_seg.currentIndex(): + self.combo_seg.setCurrentIndex(segment_index) self._qt_update_scroll_limits() if not self._block_auto_refresh_and_notify: @@ -130,8 +129,8 @@ def _qt_xsize_zoom(self, xmove): self.spinbox_xsize.setValue(newsize) def _qt_on_scroll_time(self, val): - sr = self.controller.sampling_frequency - self.timeseeker.seek(val/sr) + time = self.controller.sample_index_to_time(val) + self.timeseeker.seek(time) def _qt_seek_with_selected_spike(self): ind_selected = self.controller.get_indices_spike_selected() @@ -140,11 +139,11 @@ def _qt_seek_with_selected_spike(self): if self.settings['auto_zoom_on_select'] and n_selected == 1: ind = ind_selected[0] peak_ind = self.controller.spikes[ind]['sample_index'] - seg_index = self.controller.spikes[ind]['segment_index'] - peak_time = peak_ind / self.controller.sampling_frequency + segment_index = self.controller.spikes[ind]['segment_index'] + peak_time = self.controller.sample_index_to_time(peak_ind) - if seg_index != self.controller.get_time()[1]: - self._qt_change_segment(seg_index) + if segment_index != self.controller.get_time()[1]: + self._qt_change_segment(segment_index) self.spinbox_xsize.sigValueChanged.disconnect(self._qt_on_xsize_changed) self.xsize = self.settings['spike_selection_xsize'] @@ -154,16 +153,30 @@ def _qt_seek_with_selected_spike(self): self.controller.set_time(time=peak_time) self.notify_time_info_updated() self.refresh() + + def get_chunk_indices(self, t1, t2, segment_index): + if self.controller.main_settings["use_times"]: + recording = self.controller.analyzer.recording + ind1, ind2 = recording.time_to_sample_index([t1, t2], segment_index=segment_index) + else: + t_start = 0.0 + sr = self.controller.sampling_frequency + ind1 = int((t1 - t_start) * sr) + ind2 = int((t2 - t_start) * sr) + + ind1 = max(0, ind1) + ind2 = min(self.controller.get_num_samples(segment_index), ind2) + return ind1, ind2 ## panel ## def _panel_create_toolbar(self): import panel as pn - seg_index = self.controller.get_time()[1] + segment_index = self.controller.get_time()[1] self.segment_selector = pn.widgets.Select( name="", options=[f"Segment {i}" for i in range(self.controller.num_segments)], - value=f"Segment {seg_index}", + value=f"Segment {segment_index}", ) # Window size control @@ -189,10 +202,9 @@ def _panel_create_toolbar(self): ) # Time slider - seg_index = self.controller.get_time()[1] - length = self.controller.get_num_samples(seg_index) - t_start = 0 - t_stop = length / self.controller.sampling_frequency + segment_index = self.controller.get_time()[1] + # update with controller.get_t_start/get_t_end + t_start, t_stop = self.controller.get_t_start_t_stop() self.time_slider = pn.widgets.FloatSlider(name="Time (s)", start=t_start, end=t_stop, value=0, step=0.1, value_throttled=0, sizing_mode="stretch_width") self.time_slider.param.watch(self._panel_on_time_slider_changed, "value_throttled") @@ -201,18 +213,17 @@ def _panel_auto_scale(self, event): self.auto_scale() def _panel_on_segment_changed(self, event): - seg_index = int(event.new.split()[-1]) - self._panel_change_segment(seg_index) + segment_index = int(event.new.split()[-1]) + self._panel_change_segment(segment_index) - def _panel_change_segment(self, seg_index): - self.segment_selector.value = f"Segment {seg_index}" + def _panel_change_segment(self, segment_index): + self.segment_selector.value = f"Segment {segment_index}" # Update time slider range - length = self.controller.get_num_samples(seg_index) - t_stop = length / self.controller.sampling_frequency - self.time_slider.start = 0 + self.controller.set_time(segment_index=segment_index) + t_start, t_stop = self.controller.get_t_start_t_stop() + self.time_slider.start = t_start self.time_slider.end = t_stop - self.controller.set_time(segment_index=seg_index) if not self._block_auto_refresh_and_notify: self.refresh() self.notify_time_info_updated() @@ -236,11 +247,11 @@ def _panel_seek_with_selected_spike(self): if self.settings['auto_zoom_on_select'] and n_selected == 1: ind = ind_selected[0] peak_ind = self.controller.spikes[ind]["sample_index"] - seg_index = self.controller.spikes[ind]["segment_index"] - peak_time = peak_ind / self.controller.sampling_frequency + segment_index = self.controller.spikes[ind]["segment_index"] + peak_time = self.controller.sample_index_to_time(peak_ind) - if seg_index != self.controller.get_time()[1]: - self._panel_change_segment(seg_index) + if segment_index != self.controller.get_time()[1]: + self._panel_change_segment(segment_index) # block callbacks self._block_auto_refresh_and_notify = True @@ -309,11 +320,7 @@ def apply_gain_zoom(self, factor_ratio): self.refresh() def get_data_in_chunk(self, t1, t2, segment_index): - t_start = 0.0 - sr = self.controller.sampling_frequency - - ind1 = max(0, int((t1 - t_start) * sr)) - ind2 = min(self.controller.get_num_samples(segment_index), int((t2 - t_start) * sr)) + ind1, ind2 = self.get_chunk_indices(t1, t2, segment_index) traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2) @@ -337,7 +344,8 @@ def get_data_in_chunk(self, t1, t2, segment_index): data_curves *= gains[:, None] data_curves += offsets[:, None] - times_chunk = np.arange(traces_chunk.shape[0], dtype='float64')/self.controller.sampling_frequency+max(t1, 0) + t_start, _ = self.controller.get_t_start_t_stop() + times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency + max(t1, t_start) scatter_x = [] scatter_y = [] @@ -465,7 +473,8 @@ def _qt_seek(self, t): sr = self.controller.sampling_frequency self.scroll_time.valueChanged.disconnect(self._qt_on_scroll_time) - self.scroll_time.setValue(int(sr*t)) + value = self.controller.time_to_sample_index(t) + self.scroll_time.setValue(value) self.scroll_time.setPageStep(int(sr*xsize)) self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) @@ -476,7 +485,6 @@ def _qt_seek(self, t): connect = np.ones(data_curves.shape, dtype='bool') connect[:, -1] = 0 - times_chunk_tile = np.tile(times_chunk, visible_channel_inds.size) self.signals_curve.setData(times_chunk_tile, data_curves.flatten(), connect=connect.flatten()) @@ -493,22 +501,26 @@ def _qt_seek(self, t): self.channel_labels[chan_ind].show() #ranges - self.plot.setXRange( t1, t2, padding = 0.0) - self.plot.setYRange(-.5, visible_channel_inds.size-.5, padding = 0.0) + self.plot.setXRange(t1, t2, padding=0.0) + self.plot.setYRange(-.5, visible_channel_inds.size-.5, padding=0.0) def _qt_on_time_info_updated(self): # Update segment and time slider range - time, seg_index = self.controller.get_time() + time, segment_index = self.controller.get_time() # Block auto refresh to avoid recursive calls self._block_auto_refresh_and_notify = True - self._qt_change_segment(seg_index) + self._qt_change_segment(segment_index) self.timeseeker.seek(time) self._block_auto_refresh_and_notify = False # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() + def _qt_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.timeseeker.set_start_stop(t_start, t_stop) ## panel ## def _panel_make_layout(self): @@ -570,7 +582,7 @@ def _panel_make_layout(self): def _panel_refresh(self): - t, seg_index = self.controller.get_time() + t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 @@ -593,7 +605,7 @@ def _panel_refresh(self): self.figure.x_range.end = t2 else: times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ - self.get_data_in_chunk(t1, t2, seg_index) + self.get_data_in_chunk(t1, t2, segment_index) self.signal_source.data.update({ "xs": [times_chunk]*n, @@ -630,9 +642,9 @@ def _panel_gain_zoom(self, event): def _panel_on_time_info_updated(self): # Update segment and time slider range - time, seg_index = self.controller.get_time() + time, segment_index = self.controller.get_time() self._block_auto_refresh = True - self._panel_change_segment(seg_index) + self._panel_change_segment(segment_index) # Update time slider value self.time_slider.value = time self._block_auto_refresh = False diff --git a/spikeinterface_gui/utils_qt.py b/spikeinterface_gui/utils_qt.py index a815adc..8c09162 100644 --- a/spikeinterface_gui/utils_qt.py +++ b/spikeinterface_gui/utils_qt.py @@ -277,7 +277,8 @@ def __init__(self, parent = None, show_slider = True, show_spinbox = True) : self.set_start_stop(0., 10.) def set_start_stop(self, t_start, t_stop, seek = True): - if np.isnan(t_start) or np.isnan(t_stop): return + if np.isnan(t_start) or np.isnan(t_stop): + return assert t_stop>t_start self.t_start = t_start self.t_stop = t_stop @@ -298,7 +299,8 @@ def spinbox_changed(self, val): def seek(self, t, set_slider = True, set_spinbox = True, emit = True): self.t = t - + # print(f"Seeking {t} between {self.t_start} and {self.t_stop}") + if self.slider is not None and set_slider: self.slider.valueChanged.disconnect(self.slider_changed) pos = int((self.t - self.t_start)/(self.t_stop - self.t_start)*1000.) diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index b3624a1..d96a095 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -60,6 +60,9 @@ def notify_manual_curation_updated(self): def notify_time_info_updated(self): self.notifier.notify_time_info_updated() + def notify_use_times_updated(self): + self.notifier.notify_use_times_updated() + def notify_active_view_updated(self): # this is used for panel if self.backend == "panel": @@ -209,6 +212,12 @@ def on_time_info_updated(self): elif self.backend == "panel": self._panel_on_time_info_updated() + def on_use_times_updated(self): + if self.backend == "qt": + self._qt_on_use_times_updated() + elif self.backend == "panel": + self._panel_on_use_times_updated() + def on_unit_color_changed(self): if self.backend == "qt": self._qt_on_unit_color_changed() @@ -240,6 +249,9 @@ def _qt_on_manual_curation_updated(self): def _qt_on_time_info_updated(self): pass + def _qt_on_use_times_updated(self): + pass + def _qt_on_unit_color_changed(self): self.refresh() @@ -287,6 +299,9 @@ def _panel_on_manual_curation_updated(self): def _panel_on_time_info_updated(self): pass + def _panel_use_times_updated(self): + pass + def _panel_on_unit_color_changed(self): self.refresh() From c9c76fe25fd0a5317619571aeddc1b27d36e2776 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 28 Oct 2025 17:41:56 +0100 Subject: [PATCH 2/8] Fix gaps --- .../tests/test_mainwindow_panel.py | 12 ++++++++++++ spikeinterface_gui/tests/test_mainwindow_qt.py | 7 ++++++- spikeinterface_gui/tracemapview.py | 11 ++++++++++- spikeinterface_gui/traceview.py | 16 ++++++++++++++-- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index 6d18570..083ae8e 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -61,6 +61,18 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext ) win = None + for segment_index in range(analyzer.get_num_segments()): + shift = (segment_index + 1) * 100 + # add a gap to times + gap = 5 + times = analyzer.recording.get_times(segment_index) + times = times + shift + times[len(times)//2:] += gap # add a gap in the middle + analyzer.recording.set_times( + times, + segment_index=segment_index + ) + win = run_mainwindow( analyzer, mode="web", diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index e547154..ec53ba1 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -70,8 +70,13 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext for segment_index in range(analyzer.get_num_segments()): shift = (segment_index + 1) * 100 + # add a gap to times + gap = 1 + times = analyzer.recording.get_times(segment_index) + times = times + shift + times[len(times)//2:] += gap # add a gap in the middle analyzer.recording.set_times( - analyzer.recording.get_times(segment_index) + shift, + times, segment_index=segment_index ) diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index efd0909..404d1df 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -80,7 +80,11 @@ def get_data_in_chunk(self, t1, t2, segment_index): data_curves = data_curves.astype("float32") t_start, _ = self.controller.get_t_start_t_stop() - times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency + max(t1, t_start) + if self.controller.main_settings["use_times"]: + recording = self.controller.analyzer.recording + times_chunk = recording.get_times(segment_index=segment_index)[ind1:ind2] + else: + times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency + max(t1, t_start) scatter_x = [] scatter_y = [] @@ -197,6 +201,11 @@ def _qt_seek(self, t): segment_index = self.controller.get_time()[1] times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ self.get_data_in_chunk(t1, t2, segment_index) + + if times_chunk.size == 0: + self.image.hide() + self.scatter.clear() + return if self.color_limit is None: self.color_limit = np.max(np.abs(data_curves)) diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index d5a35c4..e01b566 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -321,6 +321,9 @@ def apply_gain_zoom(self, factor_ratio): def get_data_in_chunk(self, t1, t2, segment_index): ind1, ind2 = self.get_chunk_indices(t1, t2, segment_index) + # handle blank spots + if ind1 == ind2: + return np.array([]), np.array([[]]), [], [], [], [] traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2) @@ -345,7 +348,11 @@ def get_data_in_chunk(self, t1, t2, segment_index): data_curves += offsets[:, None] t_start, _ = self.controller.get_t_start_t_stop() - times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency + max(t1, t_start) + if self.controller.main_settings["use_times"]: + recording = self.controller.analyzer.recording + times_chunk = recording.get_times(segment_index=segment_index)[ind1:ind2] + else: + times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency + max(t1, t_start) scatter_x = [] scatter_y = [] @@ -355,7 +362,6 @@ def get_data_in_chunk(self, t1, t2, segment_index): global_to_local_chan_inds = np.zeros(self.controller.channel_ids.size, dtype='int64') global_to_local_chan_inds[visible_channel_inds] = np.arange(visible_channel_inds.size, dtype='int64') - for unit_index, unit_id in self.controller.iter_visible_units(): inds = np.flatnonzero(spikes_chunk["unit_index"] == unit_index) @@ -482,6 +488,12 @@ def _qt_seek(self, t): times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ self.get_data_in_chunk(t1, t2, self.controller.get_time()[1]) + + if times_chunk.size == 0: + self.signals_curve.setData([], []) + self.scatter.setData(x=[], y=[], brush=[]) + return + connect = np.ones(data_curves.shape, dtype='bool') connect[:, -1] = 0 From 4c1208bb8372de278bc2016e39cce7c6faee3c8d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 30 Oct 2025 13:11:03 +0100 Subject: [PATCH 3/8] Fix lasso selection and split shortcut --- spikeinterface_gui/backend_qt.py | 1 - spikeinterface_gui/basescatterview.py | 11 ++++------ spikeinterface_gui/controller.py | 2 +- spikeinterface_gui/curationview.py | 4 ---- spikeinterface_gui/myqt.py | 1 - spikeinterface_gui/spikeamplitudeview.py | 6 ++++++ spikeinterface_gui/spikedepthview.py | 1 + .../tests/test_mainwindow_qt.py | 2 +- spikeinterface_gui/tracemapview.py | 21 ++----------------- spikeinterface_gui/viewlist.py | 5 +++-- 10 files changed, 18 insertions(+), 36 deletions(-) diff --git a/spikeinterface_gui/backend_qt.py b/spikeinterface_gui/backend_qt.py index 3fcc00b..5ac386a 100644 --- a/spikeinterface_gui/backend_qt.py +++ b/spikeinterface_gui/backend_qt.py @@ -398,7 +398,6 @@ def open_help(self): def refresh(self): view = self._view() view.refresh() - areas = { 'right' : QT.Qt.RightDockWidgetArea, diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 96398c0..3dbb2f2 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -62,7 +62,7 @@ def get_selected_spikes_data(self, segment_index=0, visible_inds=None): selected_indices = np.intersect1d(selected_indices, visible_inds) mask = np.isin(sl.start + np.arange(len(spikes_in_seg)), selected_indices) selected_spikes = spikes_in_seg[mask] - spike_times = selected_spikes['sample_index'] / self.controller.sampling_frequency + spike_times = self.controller.sample_index_to_time(selected_spikes['sample_index']) spike_data = self.spike_data[sl][mask] return (spike_times, spike_data) @@ -87,7 +87,7 @@ def select_all_spikes_from_lasso(self, keep_already_selected=False): if vertices is None: continue spike_inds = self.controller.get_spike_indices(visible_unit_id, segment_index=segment_index) - spike_times = self.controller.spikes["sample_index"][spike_inds] / fs + spike_times = self.controller.sample_index_to_time(self.controller.spikes["sample_index"][spike_inds]) spike_data = self.spike_data[spike_inds] points = np.column_stack((spike_times, spike_data)) @@ -191,9 +191,6 @@ def _qt_make_layout(self): self.split_but = QT.QPushButton("split") tb.addWidget(self.split_but) self.split_but.clicked.connect(self.split) - shortcut_split = QT.QShortcut(self.qt_widget) - shortcut_split.setKey(QT.QKeySequence("ctrl+s")) - shortcut_split.activated.connect(self.split) h = QT.QHBoxLayout() self.layout.addLayout(h) @@ -399,8 +396,8 @@ def _panel_make_layout(self): self.scatter_fig.toolbar.active_drag = None self.scatter_fig.xaxis.axis_label = "Time (s)" self.scatter_fig.yaxis.axis_label = self.y_label - time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency - self.scatter_fig.x_range = Range1d(0., time_max) + t_start, t_end = self.controller.get_t_start_t_end() + self.scatter_fig.x_range = Range1d(t_start, t_end) # Add SelectionGeometry event handler to capture lasso vertices self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 78afd1f..e14fa97 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -450,7 +450,7 @@ def get_times_chunk(self, segment_index, t1, t2): recording = self.analyzer.recording times_chunk = recording.get_times(segment_index=segment_index)[ind1:ind2] else: - times_chunk = np.arange(ind2 - ind1, dtype='float64') / self.controller.sampling_frequency + max(t1, 0) + times_chunk = np.arange(ind2 - ind1, dtype='float64') / self.sampling_frequency + max(t1, 0) return times_chunk def get_chunk_indices(self, t1, t2, segment_index): diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 148eae6..6803a8c 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -64,11 +64,9 @@ def _qt_make_layout(self): from .myqt import QT import pyqtgraph as pg - self.merge_info = {} self.layout = QT.QVBoxLayout() - tb = self.qt_widget.view_toolbar if self.controller.curation_can_be_saved(): but = QT.QPushButton("Save in analyzer") @@ -92,8 +90,6 @@ def _qt_make_layout(self): self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete) self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete) - - self.delete_menu = QT.QMenu() act = self.delete_menu.addAction('Restore') act.triggered.connect(self.restore_units) diff --git a/spikeinterface_gui/myqt.py b/spikeinterface_gui/myqt.py index 12ad948..fb34cba 100644 --- a/spikeinterface_gui/myqt.py +++ b/spikeinterface_gui/myqt.py @@ -5,7 +5,6 @@ http://mikeboers.com/blog/2015/07/04/static-libraries-in-a-dynamic-world#the-fold """ - class ModuleProxy(object): def __init__(self, prefixes, modules): diff --git a/spikeinterface_gui/spikeamplitudeview.py b/spikeinterface_gui/spikeamplitudeview.py index 8536da6..c1855ae 100644 --- a/spikeinterface_gui/spikeamplitudeview.py +++ b/spikeinterface_gui/spikeamplitudeview.py @@ -28,10 +28,15 @@ def __init__(self, controller=None, parent=None, backend="qt"): ) def _qt_make_layout(self): + from .myqt import QT super()._qt_make_layout() self.noise_harea = [] if self.settings["noise_level"]: self._qt_add_noise_area() + # add split shortcut, so that it's not duplicated + shortcut_split = QT.QShortcut(self.qt_widget) + shortcut_split.setKey(QT.QKeySequence("ctrl+s")) + shortcut_split.activated.connect(self.split) def _qt_refresh(self): super()._qt_refresh() @@ -95,4 +100,5 @@ def _panel_add_noise_area(self): ### Controls - **select** : activate lasso selection to select individual spikes +- **split** or **ctrl+s** : split the selected spikes into a new unit (only if one unit is visible) """ diff --git a/spikeinterface_gui/spikedepthview.py b/spikeinterface_gui/spikedepthview.py index a357bf3..e1a3702 100644 --- a/spikeinterface_gui/spikedepthview.py +++ b/spikeinterface_gui/spikedepthview.py @@ -25,4 +25,5 @@ def __init__(self, controller=None, parent=None, backend="qt"): ### Controls - **select** : activate lasso selection to select individual spikes +- **split** or **ctrl+s** : split the selected spikes into a new unit (only if one unit is visible) """ diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index ec53ba1..d1998ce 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -71,7 +71,7 @@ def test_mainwindow(start_app=False, verbose=True, curation=False, only_some_ext for segment_index in range(analyzer.get_num_segments()): shift = (segment_index + 1) * 100 # add a gap to times - gap = 1 + gap = 5 times = analyzer.recording.get_times(segment_index) times = times + shift times[len(times)//2:] += gap # add a gap in the middle diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index e122b47..8493ce8 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -143,23 +143,16 @@ def _qt_seek(self, t): self.scroll_time.setPageStep(int(sr*xsize)) self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) -<<<<<<< HEAD segment_index = self.controller.get_time()[1] - times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ + times_chunk, data_curves, scatter_x, scatter_y, scatter_colors = \ self.get_data_in_chunk(t1, t2, segment_index) + data_curves = data_curves.T if times_chunk.size == 0: self.image.hide() self.scatter.clear() return - -======= - seg_index = self.controller.get_time()[1] - times_chunk, data_curves, scatter_x, scatter_y, scatter_colors = \ - self.get_data_in_chunk(t1, t2, seg_index) - data_curves = data_curves.T ->>>>>>> 4d644768f573b50e1ce051e7bc234fc21bc5ab93 if self.color_limit is None: self.color_limit = np.max(np.abs(data_curves)) @@ -283,14 +276,9 @@ def _panel_refresh(self): else: auto_scale = False -<<<<<<< HEAD - times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \ - self.get_data_in_chunk(t1, t2, segment_index) -======= times_chunk, data_curves, scatter_x, scatter_y, scatter_colors = \ self.get_data_in_chunk(t1, t2, seg_index) data_curves = data_curves.T ->>>>>>> 4d644768f573b50e1ce051e7bc234fc21bc5ab93 if self.color_limit is None: self.color_limit = np.max(np.abs(data_curves)) @@ -320,13 +308,8 @@ def _panel_refresh(self): # TODO: if from a different unit, change unit visibility def _panel_on_tap(self, event): -<<<<<<< HEAD - segment_index = self.controller.get_time()[1] - ind_spike_nearest = self.find_nearest_spike(self.controller, event.x, segment_index) -======= seg_index = self.controller.get_time()[1] ind_spike_nearest = find_nearest_spike(self.controller, event.x, seg_index) ->>>>>>> 4d644768f573b50e1ce051e7bc234fc21bc5ab93 if ind_spike_nearest is not None: self.controller.set_indices_spike_selected([ind_spike_nearest]) self._panel_seek_with_selected_spike() diff --git a/spikeinterface_gui/viewlist.py b/spikeinterface_gui/viewlist.py index b687deb..c48e0ac 100644 --- a/spikeinterface_gui/viewlist.py +++ b/spikeinterface_gui/viewlist.py @@ -17,8 +17,10 @@ from .metricsview import MetricsView from .spikerateview import SpikeRateView +# probe and mainsettings view are first, since they affect other views (e.g., time info) possible_class_views = dict( - probe = ProbeView, # probe view is first, since it updates channels upon unit changes + probe = ProbeView, + mainsettings = MainSettingsView, unitlist = UnitListView, spikelist = SpikeListView, merge = MergeView, @@ -35,5 +37,4 @@ curation = CurationView, spikerate = SpikeRateView, metrics = MetricsView, - mainsettings = MainSettingsView, ) From 7357c6913f17ba80f38efefa10f63f3a9ef97059 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 30 Oct 2025 13:24:04 +0100 Subject: [PATCH 4/8] updates to panel --- spikeinterface_gui/backend_panel.py | 14 ++++++++++++++ spikeinterface_gui/basescatterview.py | 4 ++-- spikeinterface_gui/mainsettingsview.py | 1 + spikeinterface_gui/tracemapview.py | 5 ----- spikeinterface_gui/traceview.py | 23 +++++++++++++++++++---- spikeinterface_gui/view_base.py | 2 +- 6 files changed, 37 insertions(+), 12 deletions(-) diff --git a/spikeinterface_gui/backend_panel.py b/spikeinterface_gui/backend_panel.py index dd44111..41412ca 100644 --- a/spikeinterface_gui/backend_panel.py +++ b/spikeinterface_gui/backend_panel.py @@ -13,6 +13,7 @@ class SignalNotifier(param.Parameterized): channel_visibility_changed = param.Event() manual_curation_updated = param.Event() time_info_updated = param.Event() + use_times_updated = param.Event() active_view_updated = param.Event() unit_color_changed = param.Event() @@ -35,6 +36,9 @@ def notify_manual_curation_updated(self): def notify_time_info_updated(self): self.param.trigger("time_info_updated") + def notify_use_times_updated(self): + self.param.trigger("use_times_updated") + def notify_active_view_updated(self): # this is used to keep an "active view" in the main window # when a view triggers this event, it self-declares it as active @@ -65,6 +69,7 @@ def connect_view(self, view): view.notifier.param.watch(self.on_channel_visibility_changed, "channel_visibility_changed") view.notifier.param.watch(self.on_manual_curation_updated, "manual_curation_updated") view.notifier.param.watch(self.on_time_info_updated, "time_info_updated") + view.notifier.param.watch(self.on_use_times_updated, "use_times_updated") view.notifier.param.watch(self.on_active_view_updated, "active_view_updated") view.notifier.param.watch(self.on_unit_color_changed, "unit_color_changed") @@ -110,6 +115,15 @@ def on_time_info_updated(self, param): continue view.on_time_info_updated() + def on_use_times_updated(self, param): + # use times is updated also when a view is not active + if not self._active: + return + for view in self.controller.views: + if param.obj.view == view: + continue + view.on_use_times_updated() + def on_active_view_updated(self, param): if not self._active: return diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 3dbb2f2..575ac46 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -396,8 +396,8 @@ def _panel_make_layout(self): self.scatter_fig.toolbar.active_drag = None self.scatter_fig.xaxis.axis_label = "Time (s)" self.scatter_fig.yaxis.axis_label = self.y_label - t_start, t_end = self.controller.get_t_start_t_end() - self.scatter_fig.x_range = Range1d(t_start, t_end) + t_start, t_stop = self.controller.get_t_start_t_stop() + self.scatter_fig.x_range = Range1d(t_start, t_stop) # Add SelectionGeometry event handler to capture lasso vertices self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index 6d4c860..a052e44 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -41,6 +41,7 @@ def on_change_color_mode(self): def on_use_times(self): self.controller.main_settings['use_times'] = self.main_settings['use_times'] + print("Use times changed:", self.main_settings['use_times']) self.controller.update_time_info() self.notify_use_times_updated() # for view in self.controller.views: diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index 8493ce8..5080ce9 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -181,11 +181,6 @@ def _qt_on_time_info_updated(self): # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() - def _qt_on_use_times_updated(self): - # Update time seeker - t_start, t_stop = self.controller.get_t_start_t_stop() - self.timeseeker.set_start_stop(t_start, t_stop) - ## Panel ## def _panel_make_layout(self): import panel as pn diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index c88f197..cda953f 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -239,6 +239,11 @@ def _qt_seek_with_selected_spike(self): self.controller.set_time(time=peak_time) self.notify_time_info_updated() self.refresh() + + def _qt_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.timeseeker.set_start_stop(t_start, t_stop) ## panel ## def _panel_create_toolbar(self): @@ -341,6 +346,16 @@ def _panel_seek_with_selected_spike(self): self.refresh() self.notify_time_info_updated() + def _panel_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.time_slider.start = t_start + self.time_slider.end = t_stop + + # Optionally clamp the current value if out of range + self.time_slider.value = max(t_start, min(self.time_slider.value, t_stop)) + self.refresh() + # TODO: pan behavior like Qt? # def _panel_on_pan_start(self, event): # self.drag_state["x_start"] = event.x @@ -552,10 +567,10 @@ def _qt_on_time_info_updated(self): # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() - def _qt_on_use_times_updated(self): - # Update time seeker - t_start, t_stop = self.controller.get_t_start_t_stop() - self.timeseeker.set_start_stop(t_start, t_stop) + # def _qt_on_use_times_updated(self): + # # Update time seeker + # t_start, t_stop = self.controller.get_t_start_t_stop() + # self.timeseeker.set_start_stop(t_start, t_stop) ## panel ## def _panel_make_layout(self): diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 0fb2e2a..00c420a 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -322,7 +322,7 @@ def _panel_on_manual_curation_updated(self): def _panel_on_time_info_updated(self): pass - def _panel_use_times_updated(self): + def _panel_on_use_times_updated(self): pass def _panel_on_unit_color_changed(self): From 63b2c12e8478e15b428a20c01f31a86958c424d3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 30 Oct 2025 15:25:14 +0100 Subject: [PATCH 5/8] Fix spikerate and scatter item clicked --- spikeinterface_gui/basescatterview.py | 8 ++- spikeinterface_gui/controller.py | 2 +- spikeinterface_gui/mainsettingsview.py | 1 - spikeinterface_gui/spikerateview.py | 19 ++++--- spikeinterface_gui/tracemapview.py | 27 +++++++--- spikeinterface_gui/traceview.py | 71 ++++++++++---------------- 6 files changed, 66 insertions(+), 62 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 575ac46..c380eb2 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -461,6 +461,9 @@ def _panel_refresh(self): colors = [] segment_index = self.controller.get_time()[1] + if segment_index != self.segment_index: + self.segment_index = segment_index + self.segment_selector.value = f"Segment {self.segment_index}" visible_unit_ids = self.controller.get_visible_unit_ids() for unit_id in visible_unit_ids: @@ -487,6 +490,9 @@ def _panel_refresh(self): line_width=2, ) self.hist_lines.append(hist_lines) + t_start, t_end = self.controller.get_t_start_t_stop() + self.scatter_fig.x_range.start = t_start + self.scatter_fig.x_range.end = t_end self._max_count = max_count @@ -522,7 +528,7 @@ def _panel_change_segment(self, event): self._current_selected = 0 self.segment_index = int(self.segment_selector.value.split()[-1]) self.controller.set_time(segment_index=self.segment_index) - t_start, t_end = self.controller.get_t_start_t_end() + t_start, t_end = self.controller.get_t_start_t_stop() self.scatter_fig.x_range.start = t_start self.scatter_fig.x_range.end = t_end self.refresh() diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index e14fa97..8587c07 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -614,7 +614,7 @@ def set_indices_spike_selected(self, inds): # set time info segment_index = self.spikes['segment_index'][self._spike_selected_indices[0]] sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]] - self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index) + self.set_time(time=self.sample_index_to_time(sample_index), segment_index=segment_index) def get_spike_indices(self, unit_id, segment_index=None): if segment_index is None: diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index a052e44..6d4c860 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -41,7 +41,6 @@ def on_change_color_mode(self): def on_use_times(self): self.controller.main_settings['use_times'] = self.main_settings['use_times'] - print("Use times changed:", self.main_settings['use_times']) self.controller.update_time_info() self.notify_use_times_updated() # for view in self.controller.views: diff --git a/spikeinterface_gui/spikerateview.py b/spikeinterface_gui/spikerateview.py index c1a204e..c44156c 100644 --- a/spikeinterface_gui/spikerateview.py +++ b/spikeinterface_gui/spikerateview.py @@ -19,7 +19,6 @@ def on_time_info_updated(self): self.refresh() def on_use_times_updated(self): - print(f"Refreshing SpikeRateView") self.refresh() ## Qt ## @@ -105,25 +104,26 @@ def _panel_make_layout(self): self.segment_selector.param.watch(self._panel_change_segment, 'value') self.rate_fig = bpl.figure( - width=250, - height=250, tools="pan,wheel_zoom,reset", active_drag="pan", active_scroll="wheel_zoom", background_fill_color=_bg_color, border_fill_color=_bg_color, outline_line_color="white", + sizing_mode="stretch_both", ) self.rate_fig.toolbar.logo = None self.rate_fig.grid.visible = False self.layout = pn.Column( pn.Row(self.segment_selector, sizing_mode="stretch_width"), - pn.Row(self.rate_fig,sizing_mode="stretch_both"), + pn.Row(self.rate_fig, sizing_mode="stretch_both"), ) self.is_warning_active = False def _panel_refresh(self): + from bokeh.models import Range1d + segment_index = self.controller.get_time()[1] if segment_index != self.segment_index: self.segment_index = segment_index @@ -136,11 +136,12 @@ def _panel_refresh(self): total_frames = self.controller.final_spike_samples bins_s = self.settings['bin_s'] num_bins = total_frames[segment_index] // int(sampling_frequency) // bins_s - t_start, _ = self.controller.get_t_start_t_stop() + t_start, t_stop = self.controller.get_t_start_t_stop() # clear fig self.rate_fig.renderers = [] + max_count = 0 for unit_id in visible_unit_ids: spike_inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index) @@ -152,13 +153,15 @@ def _panel_refresh(self): color = self.get_unit_color(unit_id) line = self.rate_fig.line( - x=(bins[1:]+bins[:-1])/(2*sampling_frequency) + t_start, - y=count/bins_s, + x=(bins[1:]+bins[:-1]) / (2*sampling_frequency) + t_start, + y=count / bins_s, color=color, line_width=2, ) + max_count = max(max_count, np.max(count/bins_s)) - self.rate_fig.y_range.start = 0 + self.rate_fig.x_range = Range1d(start=t_start, end=t_stop) + self.rate_fig.y_range = Range1d(start=0, end=max_count*1.2) def _panel_change_segment(self, event): self.segment_index = int(self.segment_selector.value.split()[-1]) diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index 5080ce9..f26df9a 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -115,11 +115,6 @@ def _qt_on_settings_changed(self, do_refresh=True): def _qt_on_spike_selection_changed(self): self._qt_seek_with_selected_spike() - - def _qt_scatter_item_clicked(self, x, y): - # useless but needed for the MixinViewTrace - pass - def _qt_refresh(self): t, _ = self.controller.get_time() self._qt_seek(t) @@ -181,6 +176,12 @@ def _qt_on_time_info_updated(self): # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() + def _qt_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.timeseeker.set_start_stop(t_start, t_stop) + self.timeseeker.seek(self.controller.get_time()[0]) + ## Panel ## def _panel_make_layout(self): import panel as pn @@ -272,7 +273,7 @@ def _panel_refresh(self): auto_scale = False times_chunk, data_curves, scatter_x, scatter_y, scatter_colors = \ - self.get_data_in_chunk(t1, t2, seg_index) + self.get_data_in_chunk(t1, t2, segment_index) data_curves = data_curves.T if self.color_limit is None: @@ -304,7 +305,8 @@ def _panel_refresh(self): # TODO: if from a different unit, change unit visibility def _panel_on_tap(self, event): seg_index = self.controller.get_time()[1] - ind_spike_nearest = find_nearest_spike(self.controller, event.x, seg_index) + time = event.x + ind_spike_nearest = find_nearest_spike(self.controller, time, seg_index) if ind_spike_nearest is not None: self.controller.set_indices_spike_selected([ind_spike_nearest]) self._panel_seek_with_selected_spike() @@ -341,6 +343,17 @@ def _panel_on_time_info_updated(self): self._block_auto_refresh_and_notify = False # we don't need a refresh in panel because changing tab triggers a refresh + def _panel_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.time_slider.start = t_start + self.time_slider.end = t_stop + + # Optionally clamp the current value if out of range + self.time_slider.value = self.controller.get_time()[0] + + self.refresh() + TraceMapView._gui_help_txt = """ ## Trace Map View diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index cda953f..1075a1f 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -145,7 +145,6 @@ def _qt_initialize_plot(self): self.graphicsview.setCentralItem(self.plot) self.plot.hideButtons() self.plot.showAxis('left', False) - self.viewBox.doubleclicked.connect(self._qt_scatter_item_clicked) self.viewBox.gain_zoom.connect(self.apply_gain_zoom) @@ -240,11 +239,14 @@ def _qt_seek_with_selected_spike(self): self.notify_time_info_updated() self.refresh() - def _qt_on_use_times_updated(self): - # Update time seeker - t_start, t_stop = self.controller.get_t_start_t_stop() - self.timeseeker.set_start_stop(t_start, t_stop) - + def _qt_scatter_item_clicked(self, x, y): + ind_spike_nearest = find_nearest_spike(self.controller, x, segment_index=self.controller.get_time()[1]) + if ind_spike_nearest is not None: + self.controller.set_indices_spike_selected([ind_spike_nearest]) + + self.notify_spike_selection_changed() + self._qt_seek_with_selected_spike() + ## panel ## def _panel_create_toolbar(self): import panel as pn @@ -346,16 +348,6 @@ def _panel_seek_with_selected_spike(self): self.refresh() self.notify_time_info_updated() - def _panel_on_use_times_updated(self): - # Update time seeker - t_start, t_stop = self.controller.get_t_start_t_stop() - self.time_slider.start = t_start - self.time_slider.end = t_stop - - # Optionally clamp the current value if out of range - self.time_slider.value = max(t_start, min(self.time_slider.value, t_stop)) - self.refresh() - # TODO: pan behavior like Qt? # def _panel_on_pan_start(self, event): # self.drag_state["x_start"] = event.x @@ -476,25 +468,6 @@ def _qt_on_settings_changed(self): def _qt_on_spike_selection_changed(self): MixinViewTrace._qt_seek_with_selected_spike(self) - def _qt_scatter_item_clicked(self, x, y): - # TODO sam : make it faster without boolean mask - ind_click = int(x*self.controller.sampling_frequency ) - in_seg, = np.nonzero(self.controller.spikes['segment_index'] == self.controller.get_time()[1]) - nearest = np.argmin(np.abs(self.controller.spikes[in_seg]['sample_index'] - ind_click)) - - ind_spike_nearest = in_seg[nearest] - sample_index = self.controller.spikes[ind_spike_nearest]['sample_index'] - - if np.abs(ind_click - sample_index) > (self.controller.sampling_frequency // 30): - return - - #~ self.controller.spikes['selected'][:] = False - #~ self.controller.spikes['selected'][ind_spike_nearest] = True - self.controller.set_indices_spike_selected([ind_spike_nearest]) - - self.notify_spike_selection_changed() - self.refresh() - def _qt_refresh(self): t, _ = self.controller.get_time() self._qt_seek(t) @@ -567,10 +540,11 @@ def _qt_on_time_info_updated(self): # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() - # def _qt_on_use_times_updated(self): - # # Update time seeker - # t_start, t_stop = self.controller.get_t_start_t_stop() - # self.timeseeker.set_start_stop(t_start, t_stop) + def _qt_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.timeseeker.set_start_stop(t_start, t_stop) + self.timeseeker.seek(self.controller.get_time()[0]) ## panel ## def _panel_make_layout(self): @@ -679,7 +653,8 @@ def _panel_refresh(self): # TODO: if from a different unit, change unit visibility def _panel_on_tap(self, event): - ind_spike_nearest = find_nearest_spike(self.controller, event.x, self.controller.get_time()[1]) + time = event.x + ind_spike_nearest = find_nearest_spike(self.controller, time, self.controller.get_time()[1]) if ind_spike_nearest is not None: self.controller.set_indices_spike_selected([ind_spike_nearest]) self._panel_seek_with_selected_spike() @@ -693,8 +668,7 @@ def _panel_gain_zoom(self, event): factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3 else: factor_ratio = 1.0 - factor = 1.3 if event.delta > 0 else 1 / 1.3 - self.apply_gain_zoom(factor) + self.apply_gain_zoom(factor_ratio) def _panel_auto_scale(self, event): self.auto_scale() @@ -709,14 +683,23 @@ def _panel_on_time_info_updated(self): self._block_auto_refresh = False # we don't need a refresh in panel because changing tab triggers a refresh + def _panel_on_use_times_updated(self): + # Update time seeker + t_start, t_stop = self.controller.get_t_start_t_stop() + self.time_slider.start = t_start + self.time_slider.end = t_stop + + # Optionally clamp the current value if out of range + self.time_slider.value = self.controller.get_time()[0] + self.refresh() + -# TODO sam refactor Qt and this def find_nearest_spike(controller, x, segment_index, max_distance_samples=None): if max_distance_samples is None: max_distance_samples = controller.sampling_frequency // 30 - ind_click = int(x * controller.sampling_frequency) + ind_click = controller.time_to_sample_index(x) (in_seg,) = np.nonzero(controller.spikes["segment_index"] == segment_index) if len(in_seg) == 0: From 0f520a87428f7609bb86172f0e165a108f4be98f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 30 Oct 2025 18:10:14 +0100 Subject: [PATCH 6/8] Fix double refresh based on time info --- spikeinterface_gui/basescatterview.py | 15 ++++++++++----- spikeinterface_gui/spikerateview.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index c380eb2..9a87174 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -32,6 +32,7 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q self._lasso_vertices = {segment_index: None for segment_index in range(controller.num_segments)} # this is used in panel self._current_selected = 0 + self._block_auto_refresh_and_notify = False ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) @@ -164,11 +165,14 @@ def on_unit_visibility_changed(self): self._current_selected = self.controller.get_indices_spike_selected().size self.refresh() - def on_time_info_updated(self): - return self.refresh() + def _qt_on_time_info_updated(self): + if self.combo_seg.currentIndex() != self.controller.get_time()[1]: + self._block_auto_refresh_and_notify = True + self.refresh() + self._block_auto_refresh_and_notify = False def on_use_times_updated(self): - return self.refresh() + self.refresh() ## QT zone ## def _qt_make_layout(self): @@ -242,8 +246,9 @@ def _qt_on_spike_selection_changed(self): def _qt_change_segment(self): segment_index = self.combo_seg.currentIndex() self.controller.set_time(segment_index=segment_index) - self.refresh() - self.notify_time_info_updated() + if not self._block_auto_refresh_and_notify: + self.refresh() + self.notify_time_info_updated() def _qt_refresh(self): from .myqt import QT diff --git a/spikeinterface_gui/spikerateview.py b/spikeinterface_gui/spikerateview.py index c44156c..18518ac 100644 --- a/spikeinterface_gui/spikerateview.py +++ b/spikeinterface_gui/spikerateview.py @@ -11,12 +11,16 @@ class SpikeRateView(ViewBase): def __init__(self, controller=None, parent=None, backend="qt"): ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) + self._block_auto_refresh_and_notify = False def _on_settings_changed(self): self.refresh() - def on_time_info_updated(self): - self.refresh() + def _qt_on_time_info_updated(self): + if self.combo_seg.currentIndex() != self.controller.get_time()[1]: + self._block_auto_refresh_and_notify = True + self.refresh() + self._block_auto_refresh_and_notify = False def on_use_times_updated(self): self.refresh() @@ -46,8 +50,9 @@ def _qt_make_layout(self): def _qt_change_segment(self): segment_index = self.combo_seg.currentIndex() self.controller.set_time(segment_index=segment_index) - self.refresh() - self.notify_time_info_updated() + if not self._block_auto_refresh_and_notify: + self.refresh() + self.notify_time_info_updated() def _qt_refresh(self): import pyqtgraph as pg From 23448c9ed1b8fab863610f77f6d9bc99c4efc60f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 31 Oct 2025 12:56:54 +0100 Subject: [PATCH 7/8] Avoid recursive calls when time changes and remove self.segment_index --- spikeinterface_gui/basescatterview.py | 23 +++++++++------- spikeinterface_gui/controller.py | 21 +++++++-------- spikeinterface_gui/spikerateview.py | 14 +++++----- spikeinterface_gui/tracemapview.py | 25 ++++++----------- spikeinterface_gui/traceview.py | 39 ++++++++++++++------------- 5 files changed, 59 insertions(+), 63 deletions(-) diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index 9a87174..f36bf7c 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -361,11 +361,11 @@ def _panel_make_layout(self): self.lasso_tool = LassoSelectTool() - self.segment_index = 0 + segment_index = self.controller.get_time()[1] self.segment_selector = pn.widgets.Select( name="", options=[f"Segment {i}" for i in range(self.controller.num_segments)], - value=f"Segment {self.segment_index}", + value=f"Segment {segment_index}", ) self.segment_selector.param.watch(self._panel_change_segment, 'value') @@ -466,9 +466,10 @@ def _panel_refresh(self): colors = [] segment_index = self.controller.get_time()[1] - if segment_index != self.segment_index: - self.segment_index = segment_index - self.segment_selector.value = f"Segment {self.segment_index}" + # get view segment index from segment selector + segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value) + if segment_index != segment_index_from_selector: + self.segment_selector.value = f"Segment {segment_index}" visible_unit_ids = self.controller.get_visible_unit_ids() for unit_id in visible_unit_ids: @@ -531,8 +532,8 @@ def _panel_on_select_button(self, event): def _panel_change_segment(self, event): self._current_selected = 0 - self.segment_index = int(self.segment_selector.value.split()[-1]) - self.controller.set_time(segment_index=self.segment_index) + segment_index = int(self.segment_selector.value.split()[-1]) + self.controller.set_time(segment_index=segment_index) t_start, t_end = self.controller.get_t_start_t_stop() self.scatter_fig.x_range.start = t_start self.scatter_fig.x_range.end = t_end @@ -555,7 +556,7 @@ def _on_panel_selection_geometry(self, event): return # Append the current polygon to the lasso vertices if shift is held - segment_index = self.segment_index + segment_index = self.controller.get_time()[1] if self._lasso_vertices[segment_index] is None: self._lasso_vertices[segment_index] = [] if len(selected) > self._current_selected: @@ -582,7 +583,8 @@ def _panel_update_selected_spikes(self): selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds) if len(selected_spike_indices) > 0: # map absolute indices to visible spikes - sl = self.controller.segment_slices[self.segment_index] + segment_index = self.controller.get_time()[1] + sl = self.controller.segment_slices[segment_index] spikes_in_seg = self.controller.spikes[sl] visible_mask = np.zeros(len(spikes_in_seg), dtype=bool) for unit_index, unit_id in self.controller.iter_visible_units(): @@ -604,7 +606,8 @@ def _panel_on_spike_selection_changed(self): return elif len(selected_indices) == 1: selected_segment = self.controller.spikes[selected_indices[0]]['segment_index'] - if selected_segment != self.segment_index: + segment_index = self.controller.get_time()[1] + if selected_segment != segment_index: self.segment_selector.value = f"Segment {selected_segment}" self._panel_change_segment(None) # update selected spikes diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 8587c07..e54177e 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -60,11 +60,8 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.verbose = verbose t0 = time.perf_counter() - self.main_settings = _default_main_settings.copy() - - self.num_channels = self.analyzer.get_num_channels() # this now private and shoudl be acess using function self._visible_unit_ids = [self.unit_ids[0]] @@ -421,19 +418,21 @@ def set_time(self, time=None, segment_index=None): def update_time_info(self): # set default time info if self.main_settings["use_times"] and self.analyzer.has_recording(): - self.time_info = dict( - time_by_seg=np.array( - [ - self.analyzer.recording.get_start_time(segment_index) for segment_index in range(self.num_segments) - ], - dtype="float64"), - segment_index=0 + time_by_seg=np.array( + [ + self.analyzer.recording.get_start_time(segment_index) for segment_index in range(self.num_segments) + ], + dtype="float64" ) else: + time_by_seg=np.array([0] * self.num_segments, dtype="float64") + if not hasattr(self, 'time_info'): self.time_info = dict( - time_by_seg=np.array([0] * self.num_segments, dtype="float64"), + time_by_seg=time_by_seg, segment_index=0 ) + else: + self.time_info['time_by_seg'] = time_by_seg def get_t_start_t_stop(self): segment_index = self.time_info["segment_index"] diff --git a/spikeinterface_gui/spikerateview.py b/spikeinterface_gui/spikerateview.py index 18518ac..38163d6 100644 --- a/spikeinterface_gui/spikerateview.py +++ b/spikeinterface_gui/spikerateview.py @@ -100,11 +100,11 @@ def _panel_make_layout(self): import bokeh.plotting as bpl from .utils_panel import _bg_color - self.segment_index = 0 + segment_index = self.controller.get_time()[1] self.segment_selector = pn.widgets.Select( name="", options=[f"Segment {i}" for i in range(self.controller.num_segments)], - value=f"Segment {self.segment_index}", + value=f"Segment {segment_index}", ) self.segment_selector.param.watch(self._panel_change_segment, 'value') @@ -130,9 +130,9 @@ def _panel_refresh(self): from bokeh.models import Range1d segment_index = self.controller.get_time()[1] - if segment_index != self.segment_index: - self.segment_index = segment_index - self.segment_selector.value = f"Segment {self.segment_index}" + segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value) + if segment_index != segment_index_from_selector: + self.segment_selector.value = f"Segment {segment_index}" visible_unit_ids = self.controller.get_visible_unit_ids() @@ -169,8 +169,8 @@ def _panel_refresh(self): self.rate_fig.y_range = Range1d(start=0, end=max_count*1.2) def _panel_change_segment(self, event): - self.segment_index = int(self.segment_selector.value.split()[-1]) - self.controller.set_time(segment_index=self.segment_index) + segment_index = int(self.segment_selector.value.split()[-1]) + self.controller.set_time(segment_index=segment_index) self.refresh() self.notify_time_info_updated() diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index f26df9a..97818d0 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -168,19 +168,19 @@ def _qt_on_time_info_updated(self): time, segment_index = self.controller.get_time() # Block auto refresh to avoid recursive calls self._block_auto_refresh_and_notify = True - self._qt_change_segment(segment_index) self.timeseeker.seek(time) - - self._block_auto_refresh_and_notify = False - # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() + self._block_auto_refresh_and_notify = False def _qt_on_use_times_updated(self): - # Update time seeker + # Block auto refresh to avoid recursive calls + self._block_auto_refresh_and_notify = True t_start, t_stop = self.controller.get_t_start_t_stop() self.timeseeker.set_start_stop(t_start, t_stop) self.timeseeker.seek(self.controller.get_time()[0]) + self.refresh() + self._block_auto_refresh_and_notify = False ## Panel ## def _panel_make_layout(self): @@ -188,7 +188,7 @@ def _panel_make_layout(self): import bokeh.plotting as bpl from .utils_panel import _bg_color from bokeh.models import ColumnDataSource, LinearColorMapper, Range1d - from bokeh.events import MouseWheel, Tap + from bokeh.events import MouseWheel, DoubleTap # Create figure @@ -203,7 +203,7 @@ def _panel_make_layout(self): self.figure.toolbar.logo = None self.figure.on_event(MouseWheel, self._panel_gain_zoom) - self.figure.on_event(Tap, self._panel_on_tap) + self.figure.on_event(DoubleTap, self._panel_on_double_tap) # Add selection line self.selection_line = self.figure.line( @@ -302,16 +302,6 @@ def _panel_refresh(self): self.figure.x_range.end = t2 self.figure.y_range.end = data_curves.shape[1] - # TODO: if from a different unit, change unit visibility - def _panel_on_tap(self, event): - seg_index = self.controller.get_time()[1] - time = event.x - ind_spike_nearest = find_nearest_spike(self.controller, time, seg_index) - if ind_spike_nearest is not None: - self.controller.set_indices_spike_selected([ind_spike_nearest]) - self._panel_seek_with_selected_spike() - self.notify_spike_selection_changed() - def _panel_on_settings_changed(self): self.make_color_lut() self.refresh() @@ -365,4 +355,5 @@ def _panel_on_use_times_updated(self): * **auto scale**: Automatically adjust the scale of the traces. * **time (s)**: Set the time point to display traces. * **mouse wheel**: change the scale of the traces. +* **double click**: select the nearest spike and center the view on it. """ \ No newline at end of file diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index 1075a1f..da46b3c 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -210,7 +210,7 @@ def _qt_xsize_zoom(self, xmove): factor = xmove/100. newsize = self.xsize*(factor+1.) limits = self.spinbox_xsize.opts['bounds'] - if newsize>0. and newsize 0. and newsize < limits[1]: self.spinbox_xsize.setValue(newsize) def _qt_on_scroll_time(self, val): @@ -348,6 +348,14 @@ def _panel_seek_with_selected_spike(self): self.refresh() self.notify_time_info_updated() + def _panel_on_double_tap(self, event): + time = event.x + ind_spike_nearest = find_nearest_spike(self.controller, time, self.controller.get_time()[1]) + if ind_spike_nearest is not None: + self.controller.set_indices_spike_selected([ind_spike_nearest]) + self._panel_seek_with_selected_spike() + self.notify_spike_selection_changed() + # TODO: pan behavior like Qt? # def _panel_on_pan_start(self, event): # self.drag_state["x_start"] = event.x @@ -535,16 +543,17 @@ def _qt_on_time_info_updated(self): self._qt_change_segment(segment_index) self.timeseeker.seek(time) - - self._block_auto_refresh_and_notify = False - # we need refresh in QT because changing tab/docking/undocking doesn't trigger a refresh self.refresh() + self._block_auto_refresh_and_notify = False def _qt_on_use_times_updated(self): # Update time seeker + self._block_auto_refresh_and_notify = True t_start, t_stop = self.controller.get_t_start_t_stop() self.timeseeker.set_start_stop(t_start, t_stop) self.timeseeker.seek(self.controller.get_time()[0]) + self.refresh() + self._block_auto_refresh_and_notify = False ## panel ## def _panel_make_layout(self): @@ -552,7 +561,7 @@ def _panel_make_layout(self): import bokeh.plotting as bpl from .utils_panel import _bg_color from bokeh.models import ColumnDataSource, Range1d - from bokeh.events import Tap, MouseWheel + from bokeh.events import DoubleTap, MouseWheel self.figure = bpl.figure( sizing_mode="stretch_both", @@ -592,7 +601,7 @@ def _panel_make_layout(self): x="x", y="y", size=10, fill_color="color", fill_alpha=self.settings['alpha'], source=self.spike_source ) - self.figure.on_event(Tap, self._panel_on_tap) + self.figure.on_event(DoubleTap, self._panel_on_double_tap) self._panel_create_toolbar() @@ -652,13 +661,6 @@ def _panel_refresh(self): self.figure.y_range.end = n - 0.5 # TODO: if from a different unit, change unit visibility - def _panel_on_tap(self, event): - time = event.x - ind_spike_nearest = find_nearest_spike(self.controller, time, self.controller.get_time()[1]) - if ind_spike_nearest is not None: - self.controller.set_indices_spike_selected([ind_spike_nearest]) - self._panel_seek_with_selected_spike() - self.notify_spike_selection_changed() def _panel_on_spike_selection_changed(self): self._panel_seek_with_selected_spike() @@ -676,22 +678,22 @@ def _panel_auto_scale(self, event): def _panel_on_time_info_updated(self): # Update segment and time slider range time, segment_index = self.controller.get_time() - self._block_auto_refresh = True + self._block_auto_refresh_and_notify = True self._panel_change_segment(segment_index) # Update time slider value self.time_slider.value = time - self._block_auto_refresh = False - # we don't need a refresh in panel because changing tab triggers a refresh + self.refresh() + self._block_auto_refresh_and_notify = False def _panel_on_use_times_updated(self): # Update time seeker t_start, t_stop = self.controller.get_t_start_t_stop() + self._block_auto_refresh_and_notify = True self.time_slider.start = t_start self.time_slider.end = t_stop - - # Optionally clamp the current value if out of range self.time_slider.value = self.controller.get_time()[0] self.refresh() + self._block_auto_refresh_and_notify = False @@ -725,4 +727,5 @@ def find_nearest_spike(controller, x, segment_index, max_distance_samples=None): * **auto scale**: Automatically adjust the scale of the traces. * **time (s)**: Set the time point to display traces. * **mouse wheel**: change the scale of the traces. +* **double click**: select the nearest spike and center the view on it. """ From 1a49a15a978c520c19f7c43e246d869384093d3a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 3 Nov 2025 16:05:01 +0100 Subject: [PATCH 8/8] remove time info updated when spike is selected --- spikeinterface_gui/traceview.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index da46b3c..f7e8ea6 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -236,7 +236,6 @@ def _qt_seek_with_selected_spike(self): self.spinbox_xsize.sigValueChanged.connect(self._qt_on_xsize_changed) self.controller.set_time(time=peak_time) - self.notify_time_info_updated() self.refresh() def _qt_scatter_item_clicked(self, x, y):