diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 3034b4f..652e502 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -1,4 +1,5 @@ import time +from copy import deepcopy import numpy as np @@ -12,11 +13,12 @@ import spikeinterface.qualitymetrics from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.core.core_tools import check_json -from spikeinterface.curation import validate_curation_dict +from spikeinterface.curation import validate_curation_dict, apply_curation from spikeinterface.curation.curation_model import CurationModel from spikeinterface.widgets.utils import make_units_table_from_analyzer from .curation_tools import add_merge, default_label_definitions, empty_curation_data +from .utils_global import add_new_unit_ids_to_curation_dict spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'), ('channel_index', 'int64'), ('segment_index', 'int64'), @@ -26,7 +28,9 @@ _default_main_settings = dict( max_visible_units=10, color_mode='color_by_unit', - use_times=False + use_times=False, + merge_new_id_strategy = 'take_first', + split_new_id_strategy = 'append', ) from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties @@ -44,6 +48,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.backend = backend self.disable_save_settings_button = disable_save_settings_button self.current_curation_saved = True + self.applied_curations = [] if self.backend == "qt": from .backend_qt import SignalHandler @@ -54,18 +59,130 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.signal_handler = SignalHandler(self, parent=parent) self.with_traces = with_traces + self.main_settings = _default_main_settings.copy() + self.save_on_compute = save_on_compute + self.verbose = verbose + + self.original_analyzer = None + self.set_analyzer_info(analyzer) + self.units_table = make_units_table_from_analyzer(self.analyzer, extra_properties=extra_unit_properties) + + self.extra_unit_properties_names = list(extra_unit_properties.keys()) + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + displayed_unit_properties += list(extra_unit_properties.keys()) + displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] + self.displayed_unit_properties = displayed_unit_properties + + # spikeinterface handle colors in matplotlib style tuple values in range (0,1) + self.refresh_colors() + + self._potential_merges = None + self.curation = curation + # TODO: Reload the dictionary if it already exists + if self.curation: + # rules: + # * if user sends curation_data, then it is used + # * otherwise, if curation_data already exists in folder it is used + # * otherwise create an empty one + + if curation_data is not None: + # validate the curation data + format_version = curation_data.get("format_version", None) + # assume version 2 if not present + if format_version is None: + raise ValueError("Curation data format version is missing and is required in the curation data.") + try: + validate_curation_dict(curation_data) + except Exception as e: + raise ValueError(f"Invalid curation data.\nError: {e}") + + if curation_data.get("merges") is None: + curation_data["merges"] = [] + else: + # here we reset the merges for better formatting (str) + existing_merges = curation_data["merges"] + new_merges = [] + for m in existing_merges: + if "unit_ids" not in m: + continue + if len(m["unit_ids"]) < 2: + continue + new_merges = add_merge(new_merges, m["unit_ids"]) + curation_data["merges"] = new_merges + if curation_data.get("splits") is None: + curation_data["splits"] = [] + if curation_data.get("removed") is None: + curation_data["removed"] = [] + + elif self.analyzer.format == "binary_folder": + json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json" + if json_file.exists(): + with open(json_file, "r") as f: + curation_data = json.load(f) + + elif self.analyzer.format == "zarr": + import zarr + zarr_root = zarr.open(self.analyzer.folder, mode='r') + if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys(): + curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"] + + if curation_data is None: + curation_data = deepcopy(empty_curation_data) + curation_data["label_definitions"] = default_label_definitions.copy() + + if curation_data.get("discard_spikes") is None: + curation_data["discard_spikes"] = [] + + self.curation_data = curation_data + + if "label_definitions" not in self.curation_data: + if label_definitions is not None: + self.curation_data["label_definitions"] = label_definitions + + self.has_default_quality_labels = False + if "quality" in self.curation_data["label_definitions"]: + curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"] + default_quality_labels = default_label_definitions["quality"]["label_options"] + if set(curation_dict_quality_labels) == set(default_quality_labels): + if self.verbose: + print('Curation quality labels are the default ones') + self.has_default_quality_labels = True + + def check_is_view_possible(self, view_name): + from .viewlist import get_all_possible_views + possible_class_views = get_all_possible_views() + view_class = possible_class_views[view_name] + if view_class._depend_on is not None: + depencies_ok = all(self.has_extension(k) for k in view_class._depend_on) + if not depencies_ok: + if self.verbose: + print(view_name, 'does not have all dependencies', view_class._depend_on) + return False + return True + + def declare_a_view(self, new_view): + assert new_view not in self.views, 'view already declared {}'.format(self) + self.views.append(new_view) + self.signal_handler.connect_view(new_view) + + @property + def channel_ids(self): + return self.analyzer.channel_ids + + @property + def unit_ids(self): + return self.analyzer.unit_ids + + def set_analyzer_info(self, analyzer): self.analyzer = analyzer assert self.analyzer.get_extension("random_spikes") is not None self.return_in_uV = self.analyzer.return_in_uV - self.save_on_compute = save_on_compute - - self.verbose = verbose t0 = time.perf_counter() - self.main_settings = _default_main_settings.copy() - self.num_channels = self.analyzer.get_num_channels() # this now private and shoudl be acess using function self._visible_unit_ids = [self.unit_ids[0]] @@ -79,7 +196,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.analyzer_sparsity = self.analyzer.sparsity # Mandatory extensions: computation forced - if verbose: + if self.verbose: print('\tLoading templates') temp_ext = self.analyzer.get_extension("templates") if temp_ext is None: @@ -93,7 +210,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save else: self.templates_std = None - if verbose: + if self.verbose: print('\tLoading unit_locations') ext = analyzer.get_extension('unit_locations') if ext is None: @@ -103,7 +220,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.unit_positions = ext.get_data()[:, :2] # Optional extensions : can be None or skipped - if verbose: + if self.verbose: print('\tLoading noise_levels') ext = analyzer.get_extension('noise_levels') if ext is None and self.has_extension('recording'): @@ -111,12 +228,12 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save ext = analyzer.compute_one_extension('noise_levels') self.noise_levels = ext.get_data() if ext is not None else None - if "quality_metrics" in skip_extensions: + if "quality_metrics" in self.skip_extensions: if self.verbose: print('\tSkipping quality_metrics') self.metrics = None else: - if verbose: + if self.verbose: print('\tLoading quality_metrics') qm_ext = analyzer.get_extension('quality_metrics') if qm_ext is not None: @@ -124,12 +241,12 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save else: self.metrics = None - if "spike_amplitudes" in skip_extensions: + if "spike_amplitudes" in self.skip_extensions: if self.verbose: print('\tSkipping spike_amplitudes') self.spike_amplitudes = None else: - if verbose: + if self.verbose: print('\tLoading spike_amplitudes') sa_ext = analyzer.get_extension('spike_amplitudes') if sa_ext is not None: @@ -137,12 +254,12 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save else: self.spike_amplitudes = None - if "spike_locations" in skip_extensions: + if "spike_locations" in self.skip_extensions: if self.verbose: print('\tSkipping spike_locations') self.spike_depths = None else: - if verbose: + if self.verbose: print('\tLoading spike_locations') sl_ext = analyzer.get_extension('spike_locations') if sl_ext is not None: @@ -150,13 +267,13 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save else: self.spike_depths = None - if "correlograms" in skip_extensions: + if "correlograms" in self.skip_extensions: if self.verbose: print('\tSkipping correlograms') self.correlograms = None self.correlograms_bins = None else: - if verbose: + if self.verbose: print('\tLoading correlograms') ccg_ext = analyzer.get_extension('correlograms') if ccg_ext is not None: @@ -164,13 +281,13 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save else: self.correlograms, self.correlograms_bins = None, None - if "isi_histograms" in skip_extensions: + if "isi_histograms" in self.skip_extensions: if self.verbose: print('\tSkipping isi_histograms') self.isi_histograms = None self.isi_bins = None else: - if verbose: + if self.verbose: print('\tLoading isi_histograms') isi_ext = analyzer.get_extension('isi_histograms') if isi_ext is not None: @@ -179,11 +296,11 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.isi_histograms, self.isi_bins = None, None self._similarity_by_method = {} - if "template_similarity" in skip_extensions: + if "template_similarity" in self.skip_extensions: if self.verbose: print('\tSkipping template_similarity') else: - if verbose: + if self.verbose: print('\tLoading template_similarity') ts_ext = analyzer.get_extension('template_similarity') if ts_ext is not None: @@ -196,12 +313,12 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save ts_ext = analyzer.compute_one_extension('template_similarity', method=method, save=save_on_compute) self._similarity_by_method[method] = ts_ext.get_data() - if "waveforms" in skip_extensions: + if "waveforms" in self.skip_extensions: if self.verbose: print('\tSkipping waveforms') self.waveforms_ext = None else: - if verbose: + if self.verbose: print('\tLoading waveforms') wf_ext = analyzer.get_extension('waveforms') if wf_ext is not None: @@ -209,20 +326,18 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save else: self.waveforms_ext = None self._pc_projections = None - if "principal_components" in skip_extensions: + if "principal_components" in self.skip_extensions: if self.verbose: print('\tSkipping principal_components') self.pc_ext = None else: - if verbose: + if self.verbose: print('\tLoading principal_components') pc_ext = analyzer.get_extension('principal_components') self.pc_ext = pc_ext - self._potential_merges = None - t1 = time.perf_counter() - if verbose: + if self.verbose: print('Loading extensions took', t1 - t0) t0 = time.perf_counter() @@ -233,9 +348,6 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.num_segments = self.analyzer.get_num_segments() self.sampling_frequency = self.analyzer.sampling_frequency - # spikeinterface handle colors in matplotlib style tuple values in range (0,1) - self.refresh_colors() - # at init, we set the visible channels as the sparsity of the first unit if self.analyzer_sparsity is not None: self.visible_channel_inds = self.analyzer_sparsity.unit_id_to_channel_indices[self.unit_ids[0]].astype("int64") @@ -287,7 +399,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self._spike_index_by_units[unit_id] = np.concatenate(inds) t1 = time.perf_counter() - if verbose: + if self.verbose: print('Gathering all spikes took', t1 - t0) self._spike_visible_indices = np.array([], dtype='int64') @@ -296,110 +408,9 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self._traces_cached = {} - self.units_table = make_units_table_from_analyzer(analyzer, extra_properties=extra_unit_properties) - if displayed_unit_properties is None: - displayed_unit_properties = list(_default_displayed_unit_properties) - if extra_unit_properties is not None: - displayed_unit_properties += list(extra_unit_properties.keys()) - displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] - self.displayed_unit_properties = displayed_unit_properties - # set default time info self.update_time_info() - self.curation = curation - # TODO: Reload the dictionary if it already exists - if self.curation: - # rules: - # * if user sends curation_data, then it is used - # * otherwise, if curation_data already exists in folder it is used - # * otherwise create an empty one - - if curation_data is not None: - # validate the curation data - format_version = curation_data.get("format_version", None) - # assume version 2 if not present - if format_version is None: - raise ValueError("Curation data format version is missing and is required in the curation data.") - try: - validate_curation_dict(curation_data) - except Exception as e: - raise ValueError(f"Invalid curation data.\nError: {e}") - - if curation_data.get("merges") is None: - curation_data["merges"] = [] - else: - # here we reset the merges for better formatting (str) - existing_merges = curation_data["merges"] - new_merges = [] - for m in existing_merges: - if "unit_ids" not in m: - continue - if len(m["unit_ids"]) < 2: - continue - new_merges = add_merge(new_merges, m["unit_ids"]) - curation_data["merges"] = new_merges - if curation_data.get("splits") is None: - curation_data["splits"] = [] - if curation_data.get("removed") is None: - curation_data["removed"] = [] - - elif self.analyzer.format == "binary_folder": - json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json" - if json_file.exists(): - with open(json_file, "r") as f: - curation_data = json.load(f) - - elif self.analyzer.format == "zarr": - import zarr - zarr_root = zarr.open(self.analyzer.folder, mode='r') - if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys(): - curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"] - - if curation_data is None: - curation_data = empty_curation_data.copy() - - self.curation_data = curation_data - - self.has_default_quality_labels = False - if "label_definitions" not in self.curation_data: - if label_definitions is not None: - self.curation_data["label_definitions"] = label_definitions - else: - self.curation_data["label_definitions"] = default_label_definitions.copy() - - if "quality" in self.curation_data["label_definitions"]: - curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"] - default_quality_labels = default_label_definitions["quality"]["label_options"] - if set(curation_dict_quality_labels) == set(default_quality_labels): - if self.verbose: - print('Curation quality labels are the default ones') - self.has_default_quality_labels = True - - def check_is_view_possible(self, view_name): - from .viewlist import get_all_possible_views - possible_class_views = get_all_possible_views() - view_class = possible_class_views[view_name] - if view_class._depend_on is not None: - depencies_ok = all(self.has_extension(k) for k in view_class._depend_on) - if not depencies_ok: - if self.verbose: - print(view_name, 'does not have all dependencies', view_class._depend_on) - return False - return True - - def declare_a_view(self, new_view): - assert new_view not in self.views, 'view already declared {}'.format(self) - self.views.append(new_view) - self.signal_handler.connect_view(new_view) - - @property - def channel_ids(self): - return self.analyzer.channel_ids - - @property - def unit_ids(self): - return self.analyzer.unit_ids def get_time(self): """ @@ -499,15 +510,22 @@ def get_information_txt(self): return txt - def refresh_colors(self): + def refresh_colors(self, existing_colors=None): if self.backend == "qt": self._cached_qcolors = {} elif self.backend == "panel": pass if self.main_settings['color_mode'] == 'color_by_unit': - self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', + shuffle=True, seed=42) + if existing_colors is None: + self.colors = unit_colors + else: + for unit_id, unit_color in unit_colors.items(): + if unit_id not in self.colors.keys(): + self.colors[unit_id] = unit_color + elif self.main_settings['color_mode'] == 'color_only_visible': unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', shuffle=True, seed=42) @@ -797,18 +815,49 @@ def compute_auto_merge(self, **params): ) return merge_unit_groups, extra - + def curation_can_be_saved(self): return self.analyzer.format != "memory" - def construct_final_curation(self): + def construct_final_curation(self, with_explicit_new_unit_ids=False): d = dict() d["format_version"] = "2" d["unit_ids"] = self.unit_ids.tolist() d.update(self.curation_data.copy()) + if with_explicit_new_unit_ids: + split_new_id_strategy = self.main_settings.get('split_new_id_strategy') + merge_new_id_strategy = self.main_settings.get('merge_new_id_strategy') + d = add_new_unit_ids_to_curation_dict(d, self.analyzer.sorting, split_new_id_strategy=split_new_id_strategy, merge_new_id_strategy=merge_new_id_strategy) model = CurationModel(**d) return model + def apply_curation(self): + + if self.original_analyzer is None: + self.original_analyzer = deepcopy(self.analyzer) + self.original_analyzer.extensions = {} + + curation = self.construct_final_curation(with_explicit_new_unit_ids=True) + curated_analyzer = apply_curation(self.analyzer, curation) + self.applied_curations.append(curation.model_dump()) + self.remove_curation() + + self.set_analyzer_info(curated_analyzer) + + # for now, don't show externally provided properties after curation + self.displayed_unit_properties = [displayed_property for displayed_property in self.displayed_unit_properties if displayed_property not in self.extra_unit_properties_names] + self.units_table = make_units_table_from_analyzer(self.analyzer) + self.refresh_colors(existing_colors=self.colors) + + for view in self.views: + view.reinitialize() + + def remove_curation(self): + label_definitioins = self.curation_data.get("label_definitions", None) + curation_data = deepcopy(empty_curation_data) + curation_data["label_definitions"] = label_definitioins + self.curation_data = curation_data + def save_curation_in_analyzer(self): if self.analyzer.format == "memory": print("Analyzer is an in-memory object. Cannot save curation file in it.") diff --git a/spikeinterface_gui/correlogramview.py b/spikeinterface_gui/correlogramview.py index 9ca6fa6..eb2664a 100644 --- a/spikeinterface_gui/correlogramview.py +++ b/spikeinterface_gui/correlogramview.py @@ -48,6 +48,10 @@ def _qt_make_layout(self): self.grid = pg.GraphicsLayoutWidget() self.layout.addWidget(self.grid) + def _qt_reinitialize(self): + self.ccg, self.bins = self.controller.get_correlograms() + self.figure_cache = {} + self._qt_refresh() def _qt_refresh(self): import pyqtgraph as pg @@ -117,6 +121,11 @@ def _panel_make_layout(self): sizing_mode="stretch_both", ) + def _panel_reinitialize(self): + self.ccg, self.bins = self.controller.get_correlograms() + self.figure_cache = {} + self._panel_refresh() + def _panel_refresh(self): import panel as pn import bokeh.plotting as bpl diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 2018d0f..59402e9 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -73,6 +73,11 @@ def _qt_make_layout(self): but = QT.QPushButton("Save in analyzer") tb.addWidget(but) but.clicked.connect(self.save_in_analyzer) + + but_apply = QT.QPushButton("Apply curation") + tb.addWidget(but_apply) + but_apply.clicked.connect(self.apply_curation_to_analyzer) + but = QT.QPushButton("Export JSON") but.clicked.connect(self._qt_export_json) tb.addWidget(but) @@ -278,6 +283,10 @@ def on_manual_curation_updated(self): def save_in_analyzer(self): self.controller.save_curation_in_analyzer() + def apply_curation_to_analyzer(self): + with self.busy_cursor(): + self.controller.apply_curation() + def _qt_export_json(self): from .myqt import QT fd = QT.QFileDialog(fileMode=QT.QFileDialog.AnyFile, acceptMode=QT.QFileDialog.AcceptSave) @@ -286,10 +295,23 @@ def _qt_export_json(self): fd.setViewMode(QT.QFileDialog.Detail) if fd.exec_(): json_file = Path(fd.selectedFiles()[0]) - curation_model = self.controller.construct_final_curation() - with json_file.open("w") as f: - f.write(curation_model.model_dump_json(indent=4)) - self.controller.current_curation_saved = True + if len(self.controller.applied_curations) == 0: + curation_model = self.controller.construct_final_curation() + with json_file.open("w") as f: + f.write(curation_model.model_dump_json(indent=4)) + self.controller.current_curation_saved = True + else: + # Keep this here until `SeqentialCuration` in release of spikeinterface + from spikeinterface.curation.curation_model import SequentialCuration + + current_curation_model = self.controller.construct_final_curation() + applied_curations = self.controller.applied_curations + current_and_applied_curations = applied_curations + [current_curation_model.model_dump()] + + sequential_curation_model = SequentialCuration(curation_steps=current_and_applied_curations) + with json_file.open("w") as f: + f.write(sequential_curation_model.model_dump_json(indent=4)) + self.controller.current_curation_saved = True # PANEL def _panel_make_layout(self): @@ -360,6 +382,13 @@ def _panel_make_layout(self): ) save_button.on_click(self._panel_save_in_analyzer) + apply_button = pn.widgets.Button( + name="Apply curation", + button_type="primary", + height=30 + ) + apply_button.on_click(self._panel_apply_curation_to_analyzer) + download_button = pn.widgets.FileDownload( button_type="primary", filename="curation.json", @@ -391,6 +420,7 @@ def _panel_make_layout(self): buttons_save = pn.Row( save_button, download_button, + apply_button, submit_button, sizing_mode="stretch_width", ) @@ -522,6 +552,9 @@ def _panel_restore_units(self, event): def _panel_unmerge(self, event): self.unmerge() + def _panel_apply_curation_to_analyzer(self, event): + self.apply_curation_to_analyzer() + def _panel_save_in_analyzer(self, event): self.save_in_analyzer() self.refresh() diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index 88ebec7..05e56bc 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -8,7 +8,9 @@ {'name': 'max_visible_units', 'type': 'int', 'value' : 10 }, {'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit', 'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']}, - {'name': 'use_times', 'type': 'bool', 'value': False} + {'name': 'use_times', 'type': 'bool', 'value': False}, + {'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join']}, + {'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split']}, ] @@ -45,6 +47,12 @@ def on_use_times(self): self.controller.update_time_info() self.notify_use_times_updated() + def on_merge_new_id_strategy(self): + self.controller.main_settings['merge_new_id_strategy'] = self.main_settings['merge_new_id_strategy'] + + def on_split_new_id_strategy(self): + self.controller.main_settings['split_new_id_strategy'] = self.main_settings['split_new_id_strategy'] + def save_current_settings(self, event=None): backend = self.controller.backend @@ -106,6 +114,8 @@ def _qt_make_layout(self): self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed) self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode) self.main_settings.param('use_times').sigValueChanged.connect(self.on_use_times) + self.main_settings.param('merge_new_id_strategy').sigValueChanged.connect(self.on_merge_new_id_strategy) + self.main_settings.param('split_new_id_strategy').sigValueChanged.connect(self.on_split_new_id_strategy) def qt_make_settings_dict(self, view): """For a given view, return the current settings in a dict""" @@ -141,6 +151,8 @@ def _panel_make_layout(self): self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units') self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode') self.main_settings._parameterized.param.watch(self._panel_on_use_times, 'use_times') + self.main_settings._parameterized.param.watch(self._panel_on_merge_new_id_strategy, 'merge_new_id_strategy') + self.main_settings._parameterized.param.watch(self._panel_on_split_new_id_strategy, 'split_new_id_strategy') self.layout = pn.Column(self.save_setting_button, self.main_settings_layout, sizing_mode="stretch_both") def panel_make_settings_dict(self, view): @@ -160,6 +172,12 @@ def _panel_on_max_visible_units_changed(self, event): def _panel_on_change_color_mode(self, event): self.on_change_color_mode() + def _panel_on_merge_new_id_strategy(self, event): + self.on_merge_new_id_strategy() + + def _panel_on_split_new_id_strategy(self, event): + self.on_split_new_id_strategy() + def _panel_on_use_times(self, event): self.on_use_times() diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 1b07c32..9867f6c 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -304,6 +304,11 @@ def _qt_on_spike_selection_changed(self): def _qt_on_unit_visibility_changed(self): pass + def _qt_reinitialize(self): + self.proposed_merge_unit_groups = [] + self.merge_info = {} + self._qt_refresh() + ## PANEL def _panel_make_layout(self): import panel as pn @@ -403,6 +408,11 @@ def _panel_refresh(self): self.table.on_click(self._panel_on_click) self.table_area.update(self.table) + def _panel_reinitialize(self): + self.proposed_merge_unit_groups = [] + self.merge_info = {} + self._panel_refresh() + def _panel_compute_merges(self, event): self._compute_merges() diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index 56099ce..4df724e 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -149,6 +149,17 @@ def _qt_make_layout(self): self.roi_units.sigRegionChangeFinished.connect(self._qt_on_roi_units_changed) + def _qt_reinitialize(self): + import pyqtgraph as pg + unit_positions = self.controller.unit_positions + brush = [self.get_unit_color(u) for u in self.controller.unit_ids] + self.scatter = pg.ScatterPlotItem(pos=unit_positions, pxMode=False, size=10, brush=brush) + + xlim0, xlim1, ylim0, ylim1 = self.get_view_bounds() + self.plot.setXRange(xlim0, xlim1) + self.plot.setYRange(ylim0, ylim1) + self._qt_refresh() + def _qt_refresh(self): current_unit_positions = self.controller.unit_positions # if not np.array_equal(current_unit_positions, self._unit_positions): @@ -479,11 +490,14 @@ def _panel_make_layout(self): self.should_resize_unit_circle = None # Main layout - self.layout = pn.Column( - self.figure, - styles={"display": "flex", "flex-direction": "column"}, - sizing_mode="stretch_both", - ) + if self.layout is None: + self.layout = pn.Column( + self.figure, + styles={"display": "flex", "flex-direction": "column"}, + sizing_mode="stretch_both", + ) + else: + self.layout.objects = [self.figure] def _panel_refresh(self): # Only update unit positions if they actually changed @@ -529,6 +543,10 @@ def _panel_refresh(self): self.y_range.start = y_min - margin self.y_range.end = y_max + margin + def _panel_reinitialize(self): + self._panel_make_layout() + self._refresh() + def _panel_update_unit_glyphs(self): # Get current data from source current_alphas = self.unit_glyphs.data_source.data['alpha'] diff --git a/spikeinterface_gui/spikeamplitudeview.py b/spikeinterface_gui/spikeamplitudeview.py index b76b071..3f6475e 100644 --- a/spikeinterface_gui/spikeamplitudeview.py +++ b/spikeinterface_gui/spikeamplitudeview.py @@ -48,6 +48,10 @@ def _qt_refresh(self): self.plot2.removeItem(n) self.noise_harea = [] + def _reinitialize(self): + self.spike_data = self.controller.spike_amplitudes + self._qt_refresh() + def _qt_add_noise_area(self): import pyqtgraph as pg diff --git a/spikeinterface_gui/spikedepthview.py b/spikeinterface_gui/spikedepthview.py index 0bee9df..032e4b5 100644 --- a/spikeinterface_gui/spikedepthview.py +++ b/spikeinterface_gui/spikedepthview.py @@ -17,6 +17,9 @@ def __init__(self, controller=None, parent=None, backend="qt"): spike_data=spike_data, ) + def _reinitialize(self): + self.spike_data = self.controller.spike_depths + self._refresh() SpikeDepthView._gui_help_txt = """ diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index ef6a695..bab7210 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -129,6 +129,9 @@ def _qt_make_layout(self): self.shortcut_noise.setKey(QT.QKeySequence('n')) self.shortcut_noise.activated.connect(lambda: self._qt_set_default_label('noise')) + def _qt_reinitialize(self): + self._qt_full_table_refresh() + self._qt_refresh() def _qt_on_column_moved(self, logical_index, old_visual_index, new_visual_index): # Update stored column order @@ -584,16 +587,22 @@ def _panel_make_layout(self): shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) shortcuts_component.on_msg(self._panel_handle_shortcut) - self.layout = pn.Column( - pn.Row( - self.info_text, - ), - buttons, - sizing_mode="stretch_width", - ) + if self.layout is None: + self.layout = pn.Column( + pn.Row( + self.info_text, + ), + buttons, + sizing_mode="stretch_width", + ) - self.layout.append(self.table) - self.layout.append(shortcuts_component) + self.layout.append(self.table) + self.layout.append(shortcuts_component) + else: + self.layout[0][0] = self.info_text + self.layout[1] = buttons + self.layout[2] = self.table + self.layout[3] = shortcuts_component self.table.tabulator.on_edit(self._panel_on_edit) @@ -650,6 +659,10 @@ def _panel_refresh(self): # refresh header self._panel_refresh_header() + def _panel_reinitialize(self): + self._panel_make_layout() + self._panel_refresh() + def _panel_refresh_header(self): unit_ids = self.controller.unit_ids n1 = len(unit_ids) @@ -675,7 +688,6 @@ def _panel_merge_units_callback(self, event): self.notifier.notify_active_view_updated() def _panel_on_visible_checkbox_toggled(self, row): - # print("checkbox toggled on row", row) unit_ids = self.table.value.index.values selected_unit_id = unit_ids[row] self.controller.set_unit_visibility(selected_unit_id, not self.controller.get_unit_visibility(selected_unit_id)) diff --git a/spikeinterface_gui/utils_global.py b/spikeinterface_gui/utils_global.py index 23fc61d..28b2666 100644 --- a/spikeinterface_gui/utils_global.py +++ b/spikeinterface_gui/utils_global.py @@ -58,3 +58,34 @@ def get_present_zones_in_half_of_layout(layout_zone, shift): is_present = [views is not None and len(views) > 0 for views in half_dict.values()] present_zones = set(np.array(list(half_dict.keys()))[np.array(is_present)]) return present_zones + + +def add_new_unit_ids_to_curation_dict(curation_dict, sorting, split_new_id_strategy, merge_new_id_strategy): + + from spikeinterface.core.sorting_tools import generate_unit_ids_for_split, generate_unit_ids_for_merge_group + from spikeinterface.curation.curation_model import CurationModel + + curation_model = CurationModel(**curation_dict) + old_unit_ids = curation_model.unit_ids + + print(f"{sorting=}") + + if len(curation_model.splits) > 0: + print(f"{curation_model.splits=}") + print(f"{split_new_id_strategy=}", flush=True) + unit_splits = {split.unit_id: split.get_full_spike_indices(sorting) for split in curation_model.splits} + new_split_unit_ids = generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy=split_new_id_strategy) + print(f"{new_split_unit_ids=}") + for split_index, new_unit_ids in enumerate(new_split_unit_ids): + curation_dict['splits'][split_index]['new_unit_ids'] = new_unit_ids + + if len(curation_model.merges) > 0: + merge_unit_groups = [m.unit_ids for m in curation_model.merges] + print(f"{merge_unit_groups=}") + print(f"{merge_new_id_strategy=}", flush=True) + new_merge_unit_ids = generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy=merge_new_id_strategy) + print(f"{new_merge_unit_ids=}") + for merge_index, new_unit_id in enumerate(new_merge_unit_ids): + curation_dict['merges'][merge_index]['new_unit_id'] = new_unit_id + + return curation_dict \ No newline at end of file diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index b9417ff..e251557 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -38,7 +38,7 @@ def __init__(self, controller=None, parent=None, backend="qt"): create_settings(self) self.notifier = SignalNotifier(view=self) self.busy = pn.indicators.LoadingSpinner(value=True, size=20, name='busy...') - + self.layout = None make_layout() if self._settings is not None: listen_setting_changes(self) @@ -113,6 +113,14 @@ def refresh(self, **kwargs): t1 = time.perf_counter() print(f"Refresh {self.__class__.__name__} took {t1 - t0:.3f} seconds", flush=True) + def reinitialize(self, **kwargs): + if self.controller.verbose: + t0 = time.perf_counter() + self._reinitialize(**kwargs) + if self.controller.verbose: + t1 = time.perf_counter() + print(f"Reinitialize {self.__class__.__name__} took {t1 - t0:.3f} seconds", flush=True) + def compute(self, event=None): with self.busy_cursor(): self._compute() @@ -127,6 +135,12 @@ def _refresh(self, **kwargs): elif self.backend == "panel": self._panel_refresh(**kwargs) + def _reinitialize(self, **kwargs): + if self.backend == "qt": + self._qt_reinitialize(**kwargs) + elif self.backend == "panel": + self._panel_reinitialize(**kwargs) + def warning(self, warning_msg): if self.backend == "qt": self._qt_insert_warning(warning_msg) @@ -249,6 +263,9 @@ def _qt_make_layout(self): def _qt_refresh(self): raise (NotImplementedError) + + def _qt_reinitialize(self): + self._qt_refresh() def _qt_on_spike_selection_changed(self): pass @@ -309,6 +326,9 @@ def _panel_make_layout(self): def _panel_refresh(self): raise (NotImplementedError) + + def _panel_reinitialize(self): + self._panel_refresh() def _panel_on_spike_selection_changed(self): pass