Skip to content

Commit af64311

Browse files
authored
Merge pull request #157 from alejoe91/curation-format-v2
Update to curation format v2
2 parents d193a8d + 848e2e3 commit af64311

File tree

10 files changed

+72
-53
lines changed

10 files changed

+72
-53
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ classifiers = [
1616
]
1717

1818
dependencies = [
19-
"spikeinterface[full]>=0.102.3",
19+
"spikeinterface[full]>=0.103.0",
2020
"markdown"
2121
]
2222

spikeinterface_gui/controller.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
import spikeinterface.qualitymetrics
1313
from spikeinterface.core.sorting_tools import spike_vector_to_indices
1414
from spikeinterface.core.core_tools import check_json
15+
from spikeinterface.curation import validate_curation_dict
1516
from spikeinterface.widgets.utils import make_units_table_from_analyzer
1617

17-
from .curation_tools import adding_group, default_label_definitions, empty_curation_data
18+
from .curation_tools import add_merge, default_label_definitions, empty_curation_data
1819

1920
spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
2021
('channel_index', 'int64'), ('segment_index', 'int64'),
@@ -26,7 +27,6 @@
2627
color_mode='color_by_unit',
2728
)
2829

29-
# TODO handle return_scaled
3030
from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties
3131

3232

@@ -53,7 +53,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
5353
self.analyzer = analyzer
5454
assert self.analyzer.get_extension("random_spikes") is not None
5555

56-
self.return_scaled = True
56+
self.return_in_uV = self.analyzer.return_in_uV
5757
self.save_on_compute = save_on_compute
5858

5959
self.verbose = verbose
@@ -328,7 +328,16 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
328328
if curation_data is None:
329329
self.curation_data = empty_curation_data.copy()
330330
else:
331-
self.curation_data = curation_data
331+
# validate the curation data
332+
format_version = curation_data.get("format_version", None)
333+
# assume version 2 if not present
334+
if format_version is None:
335+
raise ValueError("Curation data format version is missing and is required in the curation data.")
336+
try:
337+
validate_curation_dict(curation_data)
338+
self.curation_data = curation_data
339+
except Exception as e:
340+
raise ValueError(f"Invalid curation data.\nError: {e}")
332341

333342
self.has_default_quality_labels = False
334343
if "label_definitions" not in self.curation_data:
@@ -547,7 +556,7 @@ def get_traces(self, trace_source='preprocessed', **kargs):
547556
elif trace_source == 'raw':
548557
raise NotImplemented
549558
# TODO get with parent recording the non process recording
550-
kargs['return_scaled'] = self.return_scaled
559+
kargs['return_in_uV'] = self.return_in_uV
551560
traces = rec.get_traces(**kargs)
552561
# put in cache for next call
553562
self._traces_cached[cache_key] = traces
@@ -678,7 +687,7 @@ def curation_can_be_saved(self):
678687

679688
def construct_final_curation(self):
680689
d = dict()
681-
d["format_version"] = "1"
690+
d["format_version"] = "2"
682691
d["unit_ids"] = self.unit_ids.tolist()
683692
d.update(self.curation_data.copy())
684693
return d
@@ -709,14 +718,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids):
709718
if not self.curation:
710719
return
711720

712-
all_merged_units = sum(self.curation_data["merge_unit_groups"], [])
721+
all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], [])
713722
for unit_id in removed_unit_ids:
714-
if unit_id in self.curation_data["removed_units"]:
723+
if unit_id in self.curation_data["removed"]:
715724
continue
716725
# TODO: check if unit is already in a merge group
717726
if unit_id in all_merged_units:
718727
continue
719-
self.curation_data["removed_units"].append(unit_id)
728+
self.curation_data["removed"].append(unit_id)
720729
if self.verbose:
721730
print(f"Unit {unit_id} is removed from the curation data")
722731

@@ -728,10 +737,10 @@ def make_manual_restore(self, restore_unit_ids):
728737
return
729738

730739
for unit_id in restore_unit_ids:
731-
if unit_id in self.curation_data["removed_units"]:
740+
if unit_id in self.curation_data["removed"]:
732741
if self.verbose:
733742
print(f"Unit {unit_id} is restored from the curation data")
734-
self.curation_data["removed_units"].remove(unit_id)
743+
self.curation_data["removed"].remove(unit_id)
735744

736745
def make_manual_merge_if_possible(self, merge_unit_ids):
737746
"""
@@ -750,22 +759,22 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
750759
return False
751760

752761
for unit_id in merge_unit_ids:
753-
if unit_id in self.curation_data["removed_units"]:
762+
if unit_id in self.curation_data["removed"]:
754763
return False
755-
merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids)
756-
self.curation_data["merge_unit_groups"] = merged_groups
764+
765+
new_merges = add_merge(self.curation_data["merges"], merge_unit_ids)
766+
self.curation_data["merges"] = new_merges
757767
if self.verbose:
758-
print(f"Merged unit group: {merge_unit_ids}")
768+
print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}")
759769
return True
760770

761771
def make_manual_restore_merge(self, merge_group_indices):
762772
if not self.curation:
763773
return
764-
merge_groups_to_remove = [self.curation_data["merge_unit_groups"][merge_group_index] for merge_group_index in merge_group_indices]
765-
for merge_group in merge_groups_to_remove:
774+
for merge_index in merge_group_indices:
766775
if self.verbose:
767-
print(f"Unmerged merge group {merge_group}")
768-
self.curation_data["merge_unit_groups"].remove(merge_group)
776+
print(f"Unmerged merge group {self.curation_data['merge_unit_groups'][merge_index]['unit_ids']}")
777+
self.curation_data["merges"].pop(merge_index)
769778

770779
def get_curation_label_definitions(self):
771780
# give only label definition with exclusive

spikeinterface_gui/curation_tools.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,31 @@
1010

1111

1212
empty_curation_data = {
13+
"format_version": "2",
1314
"manual_labels": [],
14-
"merge_unit_groups": [],
15-
"removed_units": []
15+
"merges": [],
16+
"splits": [],
17+
"removes": []
1618
}
1719

18-
def adding_group(previous_groups, new_group):
20+
def add_merge(previous_merges, new_merge_unit_ids):
1921
# this is to ensure that np.str_ types are rendered as str
20-
to_merge = [np.array(new_group).tolist()]
22+
to_merge = [np.array(new_merge_unit_ids).tolist()]
2123
unchanged = []
22-
for c_prev in previous_groups:
24+
for c_prev in previous_merges:
2325
is_unaffected = True
24-
25-
for c_new in new_group:
26-
if c_new in c_prev:
26+
c_prev_unit_ids = c_prev["unit_ids"]
27+
for c_new in new_merge_unit_ids:
28+
if c_new in c_prev_unit_ids:
2729
is_unaffected = False
28-
to_merge.append(c_prev)
30+
to_merge.append(c_prev_unit_ids)
2931
break
3032

3133
if is_unaffected:
32-
unchanged.append(c_prev)
33-
new_merge_group = [sum(to_merge, [])]
34-
new_merge_group.extend(unchanged)
35-
# Ensure the unicity
36-
new_merge_group = [list(set(gp)) for gp in new_merge_group]
37-
return new_merge_group
34+
unchanged.append(c_prev_unit_ids)
35+
36+
new_merge_units = [sum(to_merge, [])]
37+
new_merge_units.extend(unchanged)
38+
# Ensure the uniqueness
39+
new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units]
40+
return new_merges

spikeinterface_gui/curationview.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _qt_make_layout(self):
101101
def _qt_refresh(self):
102102
from .myqt import QT
103103
# Merged
104-
merged_units = self.controller.curation_data["merge_unit_groups"]
104+
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
105105
self.table_merge.clear()
106106
self.table_merge.setRowCount(len(merged_units))
107107
self.table_merge.setColumnCount(1)
@@ -115,7 +115,7 @@ def _qt_refresh(self):
115115
self.table_merge.resizeColumnToContents(i)
116116

117117
## deleted
118-
removed_units = self.controller.curation_data["removed_units"]
118+
removed_units = self.controller.curation_data["removed"]
119119
self.table_delete.clear()
120120
self.table_delete.setRowCount(len(removed_units))
121121
self.table_delete.setColumnCount(1)
@@ -161,7 +161,7 @@ def _qt_on_item_selection_changed_merge(self):
161161

162162
dtype = self.controller.unit_ids.dtype
163163
ind = self.table_merge.selectedIndexes()[0].row()
164-
visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind]
164+
visible_unit_ids = [m["unit_ids"] for m in self.controller.curation_data["merges"]][ind]
165165
visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids]
166166
self.controller.set_visible_unit_ids(visible_unit_ids)
167167
self.notify_unit_visibility_changed()
@@ -170,7 +170,7 @@ def _qt_on_item_selection_changed_delete(self):
170170
if len(self.table_delete.selectedIndexes()) == 0:
171171
return
172172
ind = self.table_delete.selectedIndexes()[0].row()
173-
unit_id = self.controller.curation_data["removed_units"][ind]
173+
unit_id = self.controller.curation_data["removed"][ind]
174174
self.controller.set_all_unit_visibility_off()
175175
# convert to the correct type
176176
unit_id = self.controller.unit_ids.dtype.type(unit_id)
@@ -332,7 +332,7 @@ def _panel_make_layout(self):
332332
def _panel_refresh(self):
333333
import pandas as pd
334334
# Merged
335-
merged_units = self.controller.curation_data["merge_unit_groups"]
335+
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
336336

337337
# for visualization, we make all row entries strings
338338
merged_units_str = []
@@ -345,7 +345,7 @@ def _panel_refresh(self):
345345
self.table_merge.selection = []
346346

347347
## deleted
348-
removed_units = self.controller.curation_data["removed_units"]
348+
removed_units = self.controller.curation_data["removed"]
349349
removed_units = [str(unit_id) for unit_id in removed_units]
350350
df = pd.DataFrame({"deleted_unit_id": removed_units})
351351
self.table_delete.value = df

spikeinterface_gui/launcher.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,13 +538,10 @@ def instantiate_analyzer_and_recording(analyzer_path=None, recording_path=None,
538538
try:
539539
recording = si.load(recording_path)
540540
if recording_type == "raw":
541-
from spikeinterface.preprocessing.pipeline import (
542-
get_preprocessing_dict_from_analyzer,
543-
apply_preprocessing_pipeline,
544-
)
541+
import spikeinterface.preprocessing as spre
545542

546-
preprocessing_pipeline = get_preprocessing_dict_from_analyzer(analyzer_path)
547-
recording_processed = apply_preprocessing_pipeline(
543+
preprocessing_pipeline = spre.get_preprocessing_dict_from_analyzer(analyzer_path)
544+
recording_processed = spre.apply_preprocessing_pipeline(
548545
recording,
549546
preprocessing_pipeline,
550547
)

spikeinterface_gui/main.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def run_mainwindow(
2828
address="localhost",
2929
port=0,
3030
panel_start_server_kwargs=None,
31+
panel_window_servable=True,
3132
verbose=False,
3233
):
3334
"""
@@ -75,6 +76,10 @@ def run_mainwindow(
7576
- `{'dev': True}` to enable development mode (default is False).
7677
- `{'autoreload': True}` to enable autoreload of the server when files change
7778
(default is False).
79+
panel_window_servable: bool, default: True
80+
For "web" mode only. If True, the Panel app is made servable.
81+
This is useful when embedding the GUI in another Panel app. In that case,
82+
the `panel_window_servable` should be set to False.
7883
verbose: bool, default: False
7984
If True, print some information in the console
8085
"""
@@ -130,7 +135,10 @@ def run_mainwindow(
130135
elif backend == "panel":
131136
from .backend_panel import PanelMainWindow, start_server
132137
win = PanelMainWindow(controller, layout_preset=layout_preset, layout=layout)
133-
win.main_layout.servable(title='SpikeInterface GUI')
138+
139+
if start_app or panel_window_servable:
140+
win.main_layout.servable(title='SpikeInterface GUI')
141+
134142
if start_app:
135143
panel_start_server_kwargs = panel_start_server_kwargs or {}
136144
_ = start_server(win, address=address, port=port, **panel_start_server_kwargs)

spikeinterface_gui/mergeview.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def get_table_data(self, include_deleted=False):
8080
unit_ids = list(self.controller.unit_ids)
8181
for group_ids in self.proposed_merge_unit_groups:
8282
if not include_deleted and self.controller.curation:
83-
deleted_unit_ids = self.controller.curation_data["removed_units"]
83+
deleted_unit_ids = self.controller.curation_data["removed"]
8484
if any(unit_id in deleted_unit_ids for unit_id in group_ids):
8585
continue
8686

spikeinterface_gui/tests/test_mainwindow_qt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_launcher(verbose=True):
110110
if __name__ == '__main__':
111111
if not test_folder.is_dir():
112112
setup_module()
113-
# win = test_mainwindow(start_app=True, verbose=True, curation=True)
113+
win = test_mainwindow(start_app=True, verbose=True, curation=True)
114114
# win = test_mainwindow(start_app=True, verbose=True, curation=False)
115115

116-
test_launcher(verbose=True)
116+
# test_launcher(verbose=True)

spikeinterface_gui/tests/testingtools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"):
137137
def make_curation_dict(analyzer):
138138
unit_ids = analyzer.unit_ids.tolist()
139139
curation_dict = {
140+
"format_version": "2",
140141
"unit_ids": unit_ids,
141142
"label_definitions": {
142143
"quality":{
@@ -153,8 +154,8 @@ def make_curation_dict(analyzer):
153154
{'unit_id': unit_ids[2], "putative_type": ["exitatory"]},
154155
{'unit_id': unit_ids[3], "quality": ["noise"], "putative_type": ["inhibitory"]},
155156
],
156-
"merge_unit_groups": [unit_ids[:3], unit_ids[3:5]],
157-
"removed_units": unit_ids[5:8],
157+
"merges": [{"unit_ids": unit_ids[:3]}, {"unit_ids": unit_ids[3:5]}],
158+
"removed": unit_ids[5:8],
158159
}
159160
return curation_dict
160161

spikeinterface_gui/waveformview.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_spike_waveform(self, ind):
8181
trace_source='preprocessed',
8282
segment_index=seg_num,
8383
start_frame=peak_ind - nbefore, end_frame=peak_ind + nafter,
84+
return_in_uV=self.controller.return_in_uV
8485
)
8586
return wf, width
8687

0 commit comments

Comments
 (0)