Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 188 additions & 139 deletions spikeinterface_gui/controller.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions spikeinterface_gui/correlogramview.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def _qt_make_layout(self):
self.grid = pg.GraphicsLayoutWidget()
self.layout.addWidget(self.grid)

def _qt_reinitialize(self):
self.ccg, self.bins = self.controller.get_correlograms()
self.figure_cache = {}
self._qt_refresh()

def _qt_refresh(self):
import pyqtgraph as pg
Expand Down Expand Up @@ -117,6 +121,11 @@ def _panel_make_layout(self):
sizing_mode="stretch_both",
)

def _panel_reinitialize(self):
self.ccg, self.bins = self.controller.get_correlograms()
self.figure_cache = {}
self._panel_refresh()

def _panel_refresh(self):
import panel as pn
import bokeh.plotting as bpl
Expand Down
41 changes: 37 additions & 4 deletions spikeinterface_gui/curationview.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def _qt_make_layout(self):
but = QT.QPushButton("Save in analyzer")
tb.addWidget(but)
but.clicked.connect(self.save_in_analyzer)

but_apply = QT.QPushButton("Apply curation")
tb.addWidget(but_apply)
but_apply.clicked.connect(self.apply_curation_to_analyzer)

but = QT.QPushButton("Export JSON")
but.clicked.connect(self._qt_export_json)
tb.addWidget(but)
Expand Down Expand Up @@ -278,6 +283,10 @@ def on_manual_curation_updated(self):
def save_in_analyzer(self):
self.controller.save_curation_in_analyzer()

def apply_curation_to_analyzer(self):
with self.busy_cursor():
self.controller.apply_curation()

def _qt_export_json(self):
from .myqt import QT
fd = QT.QFileDialog(fileMode=QT.QFileDialog.AnyFile, acceptMode=QT.QFileDialog.AcceptSave)
Expand All @@ -286,10 +295,23 @@ def _qt_export_json(self):
fd.setViewMode(QT.QFileDialog.Detail)
if fd.exec_():
json_file = Path(fd.selectedFiles()[0])
curation_model = self.controller.construct_final_curation()
with json_file.open("w") as f:
f.write(curation_model.model_dump_json(indent=4))
self.controller.current_curation_saved = True
if len(self.controller.applied_curations) == 0:
curation_model = self.controller.construct_final_curation()
with json_file.open("w") as f:
f.write(curation_model.model_dump_json(indent=4))
self.controller.current_curation_saved = True
else:
# Keep this here until `SeqentialCuration` in release of spikeinterface
from spikeinterface.curation.curation_model import SequentialCuration

current_curation_model = self.controller.construct_final_curation()
applied_curations = self.controller.applied_curations
current_and_applied_curations = applied_curations + [current_curation_model.model_dump()]

sequential_curation_model = SequentialCuration(curation_steps=current_and_applied_curations)
with json_file.open("w") as f:
f.write(sequential_curation_model.model_dump_json(indent=4))
self.controller.current_curation_saved = True

# PANEL
def _panel_make_layout(self):
Expand Down Expand Up @@ -360,6 +382,13 @@ def _panel_make_layout(self):
)
save_button.on_click(self._panel_save_in_analyzer)

apply_button = pn.widgets.Button(
name="Apply curation",
button_type="primary",
height=30
)
apply_button.on_click(self._panel_apply_curation_to_analyzer)

download_button = pn.widgets.FileDownload(
button_type="primary",
filename="curation.json",
Expand Down Expand Up @@ -391,6 +420,7 @@ def _panel_make_layout(self):
buttons_save = pn.Row(
save_button,
download_button,
apply_button,
submit_button,
sizing_mode="stretch_width",
)
Expand Down Expand Up @@ -522,6 +552,9 @@ def _panel_restore_units(self, event):
def _panel_unmerge(self, event):
self.unmerge()

def _panel_apply_curation_to_analyzer(self, event):
self.apply_curation_to_analyzer()

def _panel_save_in_analyzer(self, event):
self.save_in_analyzer()
self.refresh()
Expand Down
20 changes: 19 additions & 1 deletion spikeinterface_gui/mainsettingsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
{'name': 'max_visible_units', 'type': 'int', 'value' : 10 },
{'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit',
'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']},
{'name': 'use_times', 'type': 'bool', 'value': False}
{'name': 'use_times', 'type': 'bool', 'value': False},
{'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join']},
{'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split']},
]


Expand Down Expand Up @@ -45,6 +47,12 @@ def on_use_times(self):
self.controller.update_time_info()
self.notify_use_times_updated()

def on_merge_new_id_strategy(self):
self.controller.main_settings['merge_new_id_strategy'] = self.main_settings['merge_new_id_strategy']

def on_split_new_id_strategy(self):
self.controller.main_settings['split_new_id_strategy'] = self.main_settings['split_new_id_strategy']

def save_current_settings(self, event=None):

backend = self.controller.backend
Expand Down Expand Up @@ -106,6 +114,8 @@ def _qt_make_layout(self):
self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed)
self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode)
self.main_settings.param('use_times').sigValueChanged.connect(self.on_use_times)
self.main_settings.param('merge_new_id_strategy').sigValueChanged.connect(self.on_merge_new_id_strategy)
self.main_settings.param('split_new_id_strategy').sigValueChanged.connect(self.on_split_new_id_strategy)

def qt_make_settings_dict(self, view):
"""For a given view, return the current settings in a dict"""
Expand Down Expand Up @@ -141,6 +151,8 @@ def _panel_make_layout(self):
self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units')
self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode')
self.main_settings._parameterized.param.watch(self._panel_on_use_times, 'use_times')
self.main_settings._parameterized.param.watch(self._panel_on_merge_new_id_strategy, 'merge_new_id_strategy')
self.main_settings._parameterized.param.watch(self._panel_on_split_new_id_strategy, 'split_new_id_strategy')
self.layout = pn.Column(self.save_setting_button, self.main_settings_layout, sizing_mode="stretch_both")

def panel_make_settings_dict(self, view):
Expand All @@ -160,6 +172,12 @@ def _panel_on_max_visible_units_changed(self, event):
def _panel_on_change_color_mode(self, event):
self.on_change_color_mode()

def _panel_on_merge_new_id_strategy(self, event):
self.on_merge_new_id_strategy()

def _panel_on_split_new_id_strategy(self, event):
self.on_split_new_id_strategy()

def _panel_on_use_times(self, event):
self.on_use_times()

Expand Down
10 changes: 10 additions & 0 deletions spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ def _qt_on_spike_selection_changed(self):
def _qt_on_unit_visibility_changed(self):
pass

def _qt_reinitialize(self):
self.proposed_merge_unit_groups = []
self.merge_info = {}
self._qt_refresh()

## PANEL
def _panel_make_layout(self):
import panel as pn
Expand Down Expand Up @@ -403,6 +408,11 @@ def _panel_refresh(self):
self.table.on_click(self._panel_on_click)
self.table_area.update(self.table)

def _panel_reinitialize(self):
self.proposed_merge_unit_groups = []
self.merge_info = {}
self._panel_refresh()

def _panel_compute_merges(self, event):
self._compute_merges()

Expand Down
28 changes: 23 additions & 5 deletions spikeinterface_gui/probeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ def _qt_make_layout(self):

self.roi_units.sigRegionChangeFinished.connect(self._qt_on_roi_units_changed)

def _qt_reinitialize(self):
import pyqtgraph as pg
unit_positions = self.controller.unit_positions
brush = [self.get_unit_color(u) for u in self.controller.unit_ids]
self.scatter = pg.ScatterPlotItem(pos=unit_positions, pxMode=False, size=10, brush=brush)

xlim0, xlim1, ylim0, ylim1 = self.get_view_bounds()
self.plot.setXRange(xlim0, xlim1)
self.plot.setYRange(ylim0, ylim1)
self._qt_refresh()

def _qt_refresh(self):
current_unit_positions = self.controller.unit_positions
# if not np.array_equal(current_unit_positions, self._unit_positions):
Expand Down Expand Up @@ -479,11 +490,14 @@ def _panel_make_layout(self):
self.should_resize_unit_circle = None

# Main layout
self.layout = pn.Column(
self.figure,
styles={"display": "flex", "flex-direction": "column"},
sizing_mode="stretch_both",
)
if self.layout is None:
self.layout = pn.Column(
self.figure,
styles={"display": "flex", "flex-direction": "column"},
sizing_mode="stretch_both",
)
else:
self.layout.objects = [self.figure]

def _panel_refresh(self):
# Only update unit positions if they actually changed
Expand Down Expand Up @@ -529,6 +543,10 @@ def _panel_refresh(self):
self.y_range.start = y_min - margin
self.y_range.end = y_max + margin

def _panel_reinitialize(self):
self._panel_make_layout()
self._refresh()

def _panel_update_unit_glyphs(self):
# Get current data from source
current_alphas = self.unit_glyphs.data_source.data['alpha']
Expand Down
4 changes: 4 additions & 0 deletions spikeinterface_gui/spikeamplitudeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def _qt_refresh(self):
self.plot2.removeItem(n)
self.noise_harea = []

def _reinitialize(self):
self.spike_data = self.controller.spike_amplitudes
self._qt_refresh()

def _qt_add_noise_area(self):
import pyqtgraph as pg

Expand Down
3 changes: 3 additions & 0 deletions spikeinterface_gui/spikedepthview.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def __init__(self, controller=None, parent=None, backend="qt"):
spike_data=spike_data,
)

def _reinitialize(self):
self.spike_data = self.controller.spike_depths
self._refresh()


SpikeDepthView._gui_help_txt = """
Expand Down
32 changes: 22 additions & 10 deletions spikeinterface_gui/unitlistview.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def _qt_make_layout(self):
self.shortcut_noise.setKey(QT.QKeySequence('n'))
self.shortcut_noise.activated.connect(lambda: self._qt_set_default_label('noise'))

def _qt_reinitialize(self):
self._qt_full_table_refresh()
self._qt_refresh()

def _qt_on_column_moved(self, logical_index, old_visual_index, new_visual_index):
# Update stored column order
Expand Down Expand Up @@ -584,16 +587,22 @@ def _panel_make_layout(self):
shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts)
shortcuts_component.on_msg(self._panel_handle_shortcut)

self.layout = pn.Column(
pn.Row(
self.info_text,
),
buttons,
sizing_mode="stretch_width",
)
if self.layout is None:
self.layout = pn.Column(
pn.Row(
self.info_text,
),
buttons,
sizing_mode="stretch_width",
)

self.layout.append(self.table)
self.layout.append(shortcuts_component)
self.layout.append(self.table)
self.layout.append(shortcuts_component)
else:
self.layout[0][0] = self.info_text
self.layout[1] = buttons
self.layout[2] = self.table
self.layout[3] = shortcuts_component

self.table.tabulator.on_edit(self._panel_on_edit)

Expand Down Expand Up @@ -650,6 +659,10 @@ def _panel_refresh(self):
# refresh header
self._panel_refresh_header()

def _panel_reinitialize(self):
self._panel_make_layout()
self._panel_refresh()

def _panel_refresh_header(self):
unit_ids = self.controller.unit_ids
n1 = len(unit_ids)
Expand All @@ -675,7 +688,6 @@ def _panel_merge_units_callback(self, event):
self.notifier.notify_active_view_updated()

def _panel_on_visible_checkbox_toggled(self, row):
# print("checkbox toggled on row", row)
unit_ids = self.table.value.index.values
selected_unit_id = unit_ids[row]
self.controller.set_unit_visibility(selected_unit_id, not self.controller.get_unit_visibility(selected_unit_id))
Expand Down
31 changes: 31 additions & 0 deletions spikeinterface_gui/utils_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,34 @@ def get_present_zones_in_half_of_layout(layout_zone, shift):
is_present = [views is not None and len(views) > 0 for views in half_dict.values()]
present_zones = set(np.array(list(half_dict.keys()))[np.array(is_present)])
return present_zones


def add_new_unit_ids_to_curation_dict(curation_dict, sorting, split_new_id_strategy, merge_new_id_strategy):

from spikeinterface.core.sorting_tools import generate_unit_ids_for_split, generate_unit_ids_for_merge_group
from spikeinterface.curation.curation_model import CurationModel

curation_model = CurationModel(**curation_dict)
old_unit_ids = curation_model.unit_ids

print(f"{sorting=}")

if len(curation_model.splits) > 0:
print(f"{curation_model.splits=}")
print(f"{split_new_id_strategy=}", flush=True)
unit_splits = {split.unit_id: split.get_full_spike_indices(sorting) for split in curation_model.splits}
new_split_unit_ids = generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy=split_new_id_strategy)
print(f"{new_split_unit_ids=}")
for split_index, new_unit_ids in enumerate(new_split_unit_ids):
curation_dict['splits'][split_index]['new_unit_ids'] = new_unit_ids

if len(curation_model.merges) > 0:
merge_unit_groups = [m.unit_ids for m in curation_model.merges]
print(f"{merge_unit_groups=}")
print(f"{merge_new_id_strategy=}", flush=True)
new_merge_unit_ids = generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy=merge_new_id_strategy)
print(f"{new_merge_unit_ids=}")
for merge_index, new_unit_id in enumerate(new_merge_unit_ids):
curation_dict['merges'][merge_index]['new_unit_id'] = new_unit_id

return curation_dict
Loading