1212import spikeinterface .qualitymetrics
1313from spikeinterface .core .sorting_tools import spike_vector_to_indices
1414from spikeinterface .core .core_tools import check_json
15+ from spikeinterface .curation import validate_curation_dict
1516from spikeinterface .widgets .utils import make_units_table_from_analyzer
1617
17- from .curation_tools import adding_group , default_label_definitions , empty_curation_data
18+ from .curation_tools import add_merge , default_label_definitions , empty_curation_data
1819
1920spike_dtype = [('sample_index' , 'int64' ), ('unit_index' , 'int64' ),
2021 ('channel_index' , 'int64' ), ('segment_index' , 'int64' ),
2627 color_mode = 'color_by_unit' ,
2728)
2829
29- # TODO handle return_scaled
3030from spikeinterface .widgets .sorting_summary import _default_displayed_unit_properties
3131
3232
@@ -53,7 +53,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
5353 self .analyzer = analyzer
5454 assert self .analyzer .get_extension ("random_spikes" ) is not None
5555
56- self .return_scaled = True
56+ self .return_in_uV = self . analyzer . return_in_uV
5757 self .save_on_compute = save_on_compute
5858
5959 self .verbose = verbose
@@ -328,7 +328,16 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
328328 if curation_data is None :
329329 self .curation_data = empty_curation_data .copy ()
330330 else :
331- self .curation_data = curation_data
331+ # validate the curation data
332+ format_version = curation_data .get ("format_version" , None )
333+ # assume version 2 if not present
334+ if format_version is None :
335+ raise ValueError ("Curation data format version is missing and is required in the curation data." )
336+ try :
337+ validate_curation_dict (curation_data )
338+ self .curation_data = curation_data
339+ except Exception as e :
340+ raise ValueError (f"Invalid curation data.\n Error: { e } " )
332341
333342 self .has_default_quality_labels = False
334343 if "label_definitions" not in self .curation_data :
@@ -547,7 +556,7 @@ def get_traces(self, trace_source='preprocessed', **kargs):
547556 elif trace_source == 'raw' :
548557 raise NotImplemented
549558 # TODO get with parent recording the non process recording
550- kargs ['return_scaled ' ] = self .return_scaled
559+ kargs ['return_in_uV ' ] = self .return_in_uV
551560 traces = rec .get_traces (** kargs )
552561 # put in cache for next call
553562 self ._traces_cached [cache_key ] = traces
@@ -678,7 +687,7 @@ def curation_can_be_saved(self):
678687
679688 def construct_final_curation (self ):
680689 d = dict ()
681- d ["format_version" ] = "1 "
690+ d ["format_version" ] = "2 "
682691 d ["unit_ids" ] = self .unit_ids .tolist ()
683692 d .update (self .curation_data .copy ())
684693 return d
@@ -709,14 +718,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids):
709718 if not self .curation :
710719 return
711720
712- all_merged_units = sum (self .curation_data ["merge_unit_groups" ], [])
721+ all_merged_units = sum ([ m [ "unit_ids" ] for m in self .curation_data ["merges" ] ], [])
713722 for unit_id in removed_unit_ids :
714- if unit_id in self .curation_data ["removed_units " ]:
723+ if unit_id in self .curation_data ["removed " ]:
715724 continue
716725 # TODO: check if unit is already in a merge group
717726 if unit_id in all_merged_units :
718727 continue
719- self .curation_data ["removed_units " ].append (unit_id )
728+ self .curation_data ["removed " ].append (unit_id )
720729 if self .verbose :
721730 print (f"Unit { unit_id } is removed from the curation data" )
722731
@@ -728,10 +737,10 @@ def make_manual_restore(self, restore_unit_ids):
728737 return
729738
730739 for unit_id in restore_unit_ids :
731- if unit_id in self .curation_data ["removed_units " ]:
740+ if unit_id in self .curation_data ["removed " ]:
732741 if self .verbose :
733742 print (f"Unit { unit_id } is restored from the curation data" )
734- self .curation_data ["removed_units " ].remove (unit_id )
743+ self .curation_data ["removed " ].remove (unit_id )
735744
736745 def make_manual_merge_if_possible (self , merge_unit_ids ):
737746 """
@@ -750,22 +759,22 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
750759 return False
751760
752761 for unit_id in merge_unit_ids :
753- if unit_id in self .curation_data ["removed_units " ]:
762+ if unit_id in self .curation_data ["removed " ]:
754763 return False
755- merged_groups = adding_group (self .curation_data ["merge_unit_groups" ], merge_unit_ids )
756- self .curation_data ["merge_unit_groups" ] = merged_groups
764+
765+ new_merges = add_merge (self .curation_data ["merges" ], merge_unit_ids )
766+ self .curation_data ["merges" ] = new_merges
757767 if self .verbose :
758- print (f"Merged unit group: { merge_unit_ids } " )
768+ print (f"Merged unit group: { [ str ( u ) for u in merge_unit_ids ] } " )
759769 return True
760770
761771 def make_manual_restore_merge (self , merge_group_indices ):
762772 if not self .curation :
763773 return
764- merge_groups_to_remove = [self .curation_data ["merge_unit_groups" ][merge_group_index ] for merge_group_index in merge_group_indices ]
765- for merge_group in merge_groups_to_remove :
774+ for merge_index in merge_group_indices :
766775 if self .verbose :
767- print (f"Unmerged merge group { merge_group } " )
768- self .curation_data ["merge_unit_groups " ].remove ( merge_group )
776+ print (f"Unmerged merge group { self . curation_data [ 'merge_unit_groups' ][ merge_index ][ 'unit_ids' ] } " )
777+ self .curation_data ["merges " ].pop ( merge_index )
769778
770779 def get_curation_label_definitions (self ):
771780 # give only label definition with exclusive
0 commit comments