Skip to content
Merged
14 changes: 14 additions & 0 deletions spikeinterface_gui/backend_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class SignalNotifier(param.Parameterized):
channel_visibility_changed = param.Event()
manual_curation_updated = param.Event()
time_info_updated = param.Event()
use_times_updated = param.Event()
active_view_updated = param.Event()
unit_color_changed = param.Event()

Expand All @@ -35,6 +36,9 @@ def notify_manual_curation_updated(self):
def notify_time_info_updated(self):
self.param.trigger("time_info_updated")

def notify_use_times_updated(self):
self.param.trigger("use_times_updated")

def notify_active_view_updated(self):
# this is used to keep an "active view" in the main window
# when a view triggers this event, it self-declares it as active
Expand Down Expand Up @@ -65,6 +69,7 @@ def connect_view(self, view):
view.notifier.param.watch(self.on_channel_visibility_changed, "channel_visibility_changed")
view.notifier.param.watch(self.on_manual_curation_updated, "manual_curation_updated")
view.notifier.param.watch(self.on_time_info_updated, "time_info_updated")
view.notifier.param.watch(self.on_use_times_updated, "use_times_updated")
view.notifier.param.watch(self.on_active_view_updated, "active_view_updated")
view.notifier.param.watch(self.on_unit_color_changed, "unit_color_changed")

Expand Down Expand Up @@ -110,6 +115,15 @@ def on_time_info_updated(self, param):
continue
view.on_time_info_updated()

def on_use_times_updated(self, param):
# use times is updated also when a view is not active
if not self._active:
return
for view in self.controller.views:
if param.obj.view == view:
continue
view.on_use_times_updated()

def on_active_view_updated(self, param):
if not self._active:
return
Expand Down
17 changes: 15 additions & 2 deletions spikeinterface_gui/backend_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SignalNotifier(QT.QObject):
channel_visibility_changed = QT.pyqtSignal()
manual_curation_updated = QT.pyqtSignal()
time_info_updated = QT.pyqtSignal()
use_times_updated = QT.pyqtSignal()
unit_color_changed = QT.pyqtSignal()

def __init__(self, parent=None, view=None):
Expand All @@ -40,6 +41,9 @@ def notify_manual_curation_updated(self):
def notify_time_info_updated(self):
self.time_info_updated.emit()

def notify_use_times_updated(self):
self.use_times_updated.emit()

def notify_unit_color_changed(self):
self.unit_color_changed.emit()

Expand All @@ -63,6 +67,7 @@ def connect_view(self, view):
view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed)
view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated)
view.notifier.time_info_updated.connect(self.on_time_info_updated)
view.notifier.use_times_updated.connect(self.on_use_times_updated)
view.notifier.unit_color_changed.connect(self.on_unit_color_changed)

def on_spike_selection_changed(self):
Expand Down Expand Up @@ -110,7 +115,16 @@ def on_time_info_updated(self):
# do not refresh it self
continue
view.on_time_info_updated()


def on_use_times_updated(self):
if not self._active:
return
for view in self.controller.views:
if view.qt_widget == self.sender().parent():
# do not refresh it self
continue
view.on_use_times_updated()

def on_unit_color_changed(self):
if not self._active:
return
Expand Down Expand Up @@ -383,7 +397,6 @@ def open_help(self):
def refresh(self):
view = self._view()
view.refresh()


areas = {
'right' : QT.Qt.RightDockWidgetArea,
Expand Down
108 changes: 71 additions & 37 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
self._lasso_vertices = {segment_index: None for segment_index in range(controller.num_segments)}
# this is used in panel
self._current_selected = 0
self._block_auto_refresh_and_notify = False

ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)


def get_unit_data(self, unit_id, seg_index=0):
inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index)
spike_times = self.controller.spikes["sample_index"][inds] / self.controller.sampling_frequency
def get_unit_data(self, unit_id, segment_index=0):
inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index)
spike_indices = self.controller.spikes["sample_index"][inds]
spike_times = self.controller.sample_index_to_time(spike_indices)
spike_data = self.spike_data[inds]
ptp = np.ptp(spike_data)
hist_min, hist_max = [np.min(spike_data) - 0.2 * ptp, np.max(spike_data) + 0.2 * ptp]
Expand All @@ -53,15 +55,15 @@ def get_unit_data(self, unit_id, seg_index=0):

return spike_times, spike_data, hist_count, hist_bins, inds

def get_selected_spikes_data(self, seg_index=0, visible_inds=None):
sl = self.controller.segment_slices[seg_index]
def get_selected_spikes_data(self, segment_index=0, visible_inds=None):
sl = self.controller.segment_slices[segment_index]
spikes_in_seg = self.controller.spikes[sl]
selected_indices = self.controller.get_indices_spike_selected()
if visible_inds is not None:
selected_indices = np.intersect1d(selected_indices, visible_inds)
mask = np.isin(sl.start + np.arange(len(spikes_in_seg)), selected_indices)
selected_spikes = spikes_in_seg[mask]
spike_times = selected_spikes['sample_index'] / self.controller.sampling_frequency
spike_times = self.controller.sample_index_to_time(selected_spikes['sample_index'])
spike_data = self.spike_data[sl][mask]
return (spike_times, spike_data)

Expand All @@ -85,8 +87,8 @@ def select_all_spikes_from_lasso(self, keep_already_selected=False):
for segment_index, vertices in self._lasso_vertices.items():
if vertices is None:
continue
spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index)
spike_times = self.controller.spikes["sample_index"][spike_inds] / fs
spike_inds = self.controller.get_spike_indices(visible_unit_id, segment_index=segment_index)
spike_times = self.controller.sample_index_to_time(self.controller.spikes["sample_index"][spike_inds])
spike_data = self.spike_data[spike_inds]

points = np.column_stack((spike_times, spike_data))
Expand Down Expand Up @@ -119,7 +121,7 @@ def split(self):

if self.controller.num_segments > 1:
# check that lasso vertices are defined for all segments
if not all(self._lasso_vertices[seg_index] is not None for seg_index in range(self.controller.num_segments)):
if not all(self._lasso_vertices[segment_index] is not None for segment_index in range(self.controller.num_segments)):
# Use the new continue_from_user pattern
self.continue_from_user(
"Not all segments have lasso selection. "
Expand Down Expand Up @@ -163,6 +165,15 @@ def on_unit_visibility_changed(self):
self._current_selected = self.controller.get_indices_spike_selected().size
self.refresh()

def _qt_on_time_info_updated(self):
if self.combo_seg.currentIndex() != self.controller.get_time()[1]:
self._block_auto_refresh_and_notify = True
self.refresh()
self._block_auto_refresh_and_notify = False

def on_use_times_updated(self):
self.refresh()

## QT zone ##
def _qt_make_layout(self):
from .myqt import QT
Expand All @@ -174,8 +185,8 @@ def _qt_make_layout(self):
tb = self.qt_widget.view_toolbar
self.combo_seg = QT.QComboBox()
tb.addWidget(self.combo_seg)
self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ])
self.combo_seg.currentIndexChanged.connect(self.refresh)
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
add_stretch_to_qtoolbar(tb)
self.lasso_but = QT.QPushButton("select", checkable = True)
tb.addWidget(self.lasso_but)
Expand All @@ -184,9 +195,6 @@ def _qt_make_layout(self):
self.split_but = QT.QPushButton("split")
tb.addWidget(self.split_but)
self.split_but.clicked.connect(self.split)
shortcut_split = QT.QShortcut(self.qt_widget)
shortcut_split.setKey(QT.QKeySequence("ctrl+s"))
shortcut_split.activated.connect(self.split)
h = QT.QHBoxLayout()
self.layout.addLayout(h)

Expand Down Expand Up @@ -235,6 +243,13 @@ def _qt_initialize_plot(self):
def _qt_on_spike_selection_changed(self):
self.refresh()

def _qt_change_segment(self):
segment_index = self.combo_seg.currentIndex()
self.controller.set_time(segment_index=segment_index)
if not self._block_auto_refresh_and_notify:
self.refresh()
self.notify_time_info_updated()

def _qt_refresh(self):
from .myqt import QT
import pyqtgraph as pg
Expand All @@ -246,13 +261,18 @@ def _qt_refresh(self):
if self.spike_data is None:
return

segment_index = self.controller.get_time()[1]
# Update combo_seg if it doesn't match the current segment index
if self.combo_seg.currentIndex() != segment_index:
self.combo_seg.setCurrentIndex(segment_index)

max_count = 1
all_inds = []
for unit_id in self.controller.get_visible_unit_ids():

spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
unit_id,
seg_index=self.combo_seg.currentIndex()
segment_index=segment_index
)

# make a copy of the color
Expand All @@ -276,7 +296,7 @@ def _qt_refresh(self):
y_range_plot_1 = self.plot.getViewBox().viewRange()
self.viewBox2.setYRange(y_range_plot_1[1][0], y_range_plot_1[1][1], padding = 0.0)

spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds)

self.scatter_select.setData(spike_times, spike_data)

Expand All @@ -296,8 +316,8 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
self.lasso.setData([], [])
vertices = np.array(points)

seg_index = self.combo_seg.currentIndex()
sl = self.controller.segment_slices[seg_index]
segment_index = self.combo_seg.currentIndex()
sl = self.controller.segment_slices[segment_index]
spikes_in_seg = self.controller.spikes[sl]

# Create mask for visible units
Expand All @@ -315,16 +335,16 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
self.notify_spike_selection_changed()
return

if self._lasso_vertices[seg_index] is None:
self._lasso_vertices[seg_index] = []
if self._lasso_vertices[segment_index] is None:
self._lasso_vertices[segment_index] = []

if shift_held:
# If shift is held, append the vertices to the current lasso vertices
self._lasso_vertices[seg_index].append(vertices)
self._lasso_vertices[segment_index].append(vertices)
keep_already_selected = True
else:
# If shift is not held, clear the existing lasso vertices for this segment
self._lasso_vertices[seg_index] = [vertices]
self._lasso_vertices[segment_index] = [vertices]
keep_already_selected = False

self.select_all_spikes_from_lasso(keep_already_selected=keep_already_selected)
Expand All @@ -341,11 +361,11 @@ def _panel_make_layout(self):

self.lasso_tool = LassoSelectTool()

self.segment_index = 0
segment_index = self.controller.get_time()[1]
self.segment_selector = pn.widgets.Select(
name="",
options=[f"Segment {i}" for i in range(self.controller.num_segments)],
value=f"Segment {self.segment_index}",
value=f"Segment {segment_index}",
)
self.segment_selector.param.watch(self._panel_change_segment, 'value')

Expand Down Expand Up @@ -381,8 +401,8 @@ def _panel_make_layout(self):
self.scatter_fig.toolbar.active_drag = None
self.scatter_fig.xaxis.axis_label = "Time (s)"
self.scatter_fig.yaxis.axis_label = self.y_label
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
self.scatter_fig.x_range = Range1d(0., time_max)
t_start, t_stop = self.controller.get_t_start_t_stop()
self.scatter_fig.x_range = Range1d(t_start, t_stop)

# Add SelectionGeometry event handler to capture lasso vertices
self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)
Expand Down Expand Up @@ -445,11 +465,17 @@ def _panel_refresh(self):
ys = []
colors = []

segment_index = self.controller.get_time()[1]
# get view segment index from segment selector
segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value)
if segment_index != segment_index_from_selector:
self.segment_selector.value = f"Segment {segment_index}"

visible_unit_ids = self.controller.get_visible_unit_ids()
for unit_id in visible_unit_ids:
spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
unit_id,
seg_index=self.segment_index
segment_index=segment_index
)
color = self.get_unit_color(unit_id)
xs.extend(spike_times)
Expand All @@ -470,6 +496,9 @@ def _panel_refresh(self):
line_width=2,
)
self.hist_lines.append(hist_lines)
t_start, t_end = self.controller.get_t_start_t_stop()
self.scatter_fig.x_range.start = t_start
self.scatter_fig.x_range.end = t_end

self._max_count = max_count

Expand Down Expand Up @@ -503,10 +532,13 @@ def _panel_on_select_button(self, event):

def _panel_change_segment(self, event):
self._current_selected = 0
self.segment_index = int(self.segment_selector.value.split()[-1])
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
self.scatter_fig.x_range.end = time_max
segment_index = int(self.segment_selector.value.split()[-1])
self.controller.set_time(segment_index=segment_index)
t_start, t_end = self.controller.get_t_start_t_stop()
self.scatter_fig.x_range.start = t_start
self.scatter_fig.x_range.end = t_end
self.refresh()
self.notify_time_info_updated()

def _on_panel_selection_geometry(self, event):
"""
Expand All @@ -524,16 +556,16 @@ def _on_panel_selection_geometry(self, event):
return

# Append the current polygon to the lasso vertices if shift is held
seg_index = self.segment_index
if self._lasso_vertices[seg_index] is None:
self._lasso_vertices[seg_index] = []
segment_index = self.controller.get_time()[1]
if self._lasso_vertices[segment_index] is None:
self._lasso_vertices[segment_index] = []
if len(selected) > self._current_selected:
self._current_selected = len(selected)
# Store the current polygon for the current segment
self._lasso_vertices[seg_index].append(polygon)
self._lasso_vertices[segment_index].append(polygon)
keep_already_selected = True
else:
self._lasso_vertices[seg_index] = [polygon]
self._lasso_vertices[segment_index] = [polygon]
keep_already_selected = False

self.select_all_spikes_from_lasso(keep_already_selected)
Expand All @@ -551,7 +583,8 @@ def _panel_update_selected_spikes(self):
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
if len(selected_spike_indices) > 0:
# map absolute indices to visible spikes
sl = self.controller.segment_slices[self.segment_index]
segment_index = self.controller.get_time()[1]
sl = self.controller.segment_slices[segment_index]
spikes_in_seg = self.controller.spikes[sl]
visible_mask = np.zeros(len(spikes_in_seg), dtype=bool)
for unit_index, unit_id in self.controller.iter_visible_units():
Expand All @@ -573,7 +606,8 @@ def _panel_on_spike_selection_changed(self):
return
elif len(selected_indices) == 1:
selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
if selected_segment != self.segment_index:
segment_index = self.controller.get_time()[1]
if selected_segment != segment_index:
self.segment_selector.value = f"Segment {selected_segment}"
self._panel_change_segment(None)
# update selected spikes
Expand Down
Loading