Link waveforms of same sort group in UnitWaveformFeaturesGroup #1501

samuelbray32 wants to merge 2 commits into master from
Conversation
Pull request overview
This PR adds functionality to link waveform features from the same sort group across different epochs/intervals, addressing issue #1449. The implementation allows users to combine spike sorting data from separate processing runs (e.g., run and sleep epochs) by concatenating spike times and waveform features for corresponding sort groups.
Changes:
- Adds `LinkedSorts` part table to track which SpikeSortingMerge IDs should be concatenated
- Implements `fetch_data` method to retrieve and merge linked sorts
- Extends `create_group` to accept optional linked merge IDs during group creation
- Updates `fetch_spike_data` to use the new `fetch_data` method
```python
raise ValueError(
    f"Linked SpikeSortingMerge ID {merge_id} not found in "
    + "UnitFeatures table"
    + f" for group {key['waveform_features_group_name']}"
```
The error message construction on lines 141-143 uses key['waveform_features_group_name'] but this key might not exist in the key parameter, as the method accepts key: dict = dict() and it's not guaranteed that the key contains this field after restriction. Consider using group_key['waveform_features_group_name'] instead, which is fetched on line 107.
Suggested change:
```diff
-    + f" for group {key['waveform_features_group_name']}"
+    + f" for group {group_key['waveform_features_group_name']}"
```
```python
df = pd.DataFrame(
    {
        "merge_id": merge_ids,
        "spike_times": spike_times,
        "waveform_features": spike_waveform_features,
    }
)
df.set_index("merge_id", inplace=True)
```
Using a pandas DataFrame to organize the spike times and waveform features, only to convert them back to lists, adds unnecessary overhead. Consider using a dictionary instead, which would be more efficient: data_dict = {k['spikesorting_merge_id']: (st, wf) for k, st, wf in zip(waveform_keys, spike_times, spike_waveform_features)}. This would avoid the DataFrame construction and make the code more straightforward.
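A minimal sketch of the dict-based approach the comment suggests; the variable names (`waveform_keys`, `spike_times`, `spike_waveform_features`) mirror the variables quoted in the diff, and the toy values here are illustrative only:

```python
# Illustrative stand-ins for the fetched data (not real PR data).
waveform_keys = [
    {"spikesorting_merge_id": "id-a"},
    {"spikesorting_merge_id": "id-b"},
]
spike_times = [[0.1, 0.5], [0.2]]
spike_waveform_features = [[[1.0]], [[2.0]]]

# Index spike data by merge_id with a plain dict instead of a DataFrame.
data_dict = {
    k["spikesorting_merge_id"]: (st, wf)
    for k, st, wf in zip(waveform_keys, spike_times, spike_waveform_features)
}

# Lookup is a single hash access; no DataFrame construction or index needed.
times, features = data_dict["id-a"]
```

A dict also makes the "one entry per merge_id" assumption explicit, since later duplicate keys simply overwrite earlier ones instead of coexisting in an index.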
```python
    {**key, **group_key},
    skip_duplicates=True,
)
if linked_ids is not None:
```
The code doesn't handle the case where linked_ids could be an empty list. If the user passes linked_ids=[], the code will still try to call insert_linked_ids with an empty list, which will iterate over nothing but still create a transaction. While this is harmless, it's inefficient. Consider checking if linked_ids instead of if linked_ids is not None to skip processing for both None and empty lists.
Suggested change:
```diff
-if linked_ids is not None:
+if linked_ids:
```
```python
def insert_linked_ids(
    self, key: dict, linked_merge_ids_list: list[list[str]]
):
    """Insert linked SpikeSortingMerge IDs for multiple waveform features groups"""
    with self.LinkedSorts.connection.transaction:
        for linked_merge_ids in linked_merge_ids_list:
            self.insert1_linked_ids(key, linked_merge_ids)
```
The code doesn't validate that linked_merge_ids contains at least 2 elements. Linking a single merge_id to itself doesn't make sense and could lead to confusion. Consider adding validation to ensure each list in linked_ids contains at least 2 merge_ids, and raise a ValueError with a clear message if not.
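One way to add the suggested check, as a hypothetical helper (the function name and message are illustrative, not part of the PR):

```python
def validate_linked_ids(linked_merge_ids: list[str]) -> None:
    """Hypothetical validation: a link must join at least two sorts,
    since linking a single merge_id to itself is meaningless."""
    if len(linked_merge_ids) < 2:
        raise ValueError(
            "Each linked_merge_ids entry must contain at least 2 merge_ids; "
            f"got {linked_merge_ids!r}"
        )
```

Called at the top of `insert1_linked_ids`, this would fail fast before any rows are written inside the transaction.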
```python
def insert_linked_ids(
    self, key: dict, linked_merge_ids_list: list[list[str]]
):
    """Insert linked SpikeSortingMerge IDs for multiple waveform features groups"""
    with self.LinkedSorts.connection.transaction:
        for linked_merge_ids in linked_merge_ids_list:
            self.insert1_linked_ids(key, linked_merge_ids)
```
The code doesn't prevent the same merge_id from appearing in multiple different linked_id groups. If a merge_id appears in two different linked_id lists, it will be included twice in the final output (once for each link), leading to duplicated data. Consider adding validation to ensure that each merge_id appears in at most one linked_id group, and raise a ValueError if duplicates are detected across different groups.
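A sketch of the cross-group duplicate check this comment asks for, again as a hypothetical helper rather than code from the PR:

```python
from collections import Counter

def check_no_cross_group_duplicates(
    linked_merge_ids_list: list[list[str]],
) -> None:
    """Hypothetical validation: each merge_id may appear in at most one
    linked group, so no unit's data is emitted twice."""
    counts = Counter(
        merge_id for group in linked_merge_ids_list for merge_id in group
    )
    duplicates = sorted(m for m, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(
            f"merge_ids appear in multiple linked groups: {duplicates}"
        )
```

Running this once over the full `linked_merge_ids_list` in `insert_linked_ids`, before the transaction, would reject the whole batch rather than leaving a partial insert.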
```python
if merge_id in df.index:
    times_list.append(df.at[merge_id, "spike_times"])
    features_list.append(df.at[merge_id, "waveform_features"])
    managed_ids.add(merge_id)
```
When checking if a merge_id exists in the DataFrame, the code uses if merge_id in df.index, but this returns True even if there are multiple occurrences of the merge_id in the index (in case of duplicates). Combined with the use of df.at[merge_id, ...], which returns the first match, this could lead to subtle bugs where only one occurrence is processed while others are silently ignored. Consider using a dictionary instead of a DataFrame with potential duplicate indices.
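A toy illustration of the duplicate-index pitfall (the data here is made up, not from the PR): membership in `df.index` says nothing about uniqueness, and label-based selection on a duplicated label returns all matching rows.

```python
import pandas as pd

# Toy DataFrame with a duplicated merge_id in the index.
df = pd.DataFrame(
    {
        "spike_times": [[0.1], [0.2], [0.3]],
        "waveform_features": [[1], [2], [3]],
    },
    index=["id-a", "id-a", "id-b"],
)

# Membership is True no matter how many duplicates exist...
assert "id-a" in df.index
# ...and .loc on a duplicated label returns *all* matching rows, so code
# that assumes one row per merge_id can silently mishandle duplicates.
assert len(df.loc["id-a"]) == 2
```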
```python
merged_spike_times = []
merged_waveform_features = []
managed_ids = set()
for linked_ids in linked_id_list:
    times_list = []
    features_list = []
    for merge_id in linked_ids:
        if merge_id in df.index:
            times_list.append(df.at[merge_id, "spike_times"])
            features_list.append(df.at[merge_id, "waveform_features"])
            managed_ids.add(merge_id)
        else:
            raise ValueError(
                f"Linked SpikeSortingMerge ID {merge_id} not found in "
                + "UnitFeatures table"
                + f" for group {key['waveform_features_group_name']}"
            )
    if times_list:
        merged_times_i = np.concatenate(times_list)
        merged_features_i = np.concatenate(features_list)
        ind_sort = np.argsort(merged_times_i)
        merged_spike_times.append(merged_times_i[ind_sort])
        merged_waveform_features.append(merged_features_i[ind_sort])

# add any remaining unlinked units
for merge_id in df.index:
    if merge_id not in managed_ids:
        merged_spike_times.append(df.at[merge_id, "spike_times"])
        merged_waveform_features.append(
            df.at[merge_id, "waveform_features"]
        )
return merged_spike_times, merged_waveform_features
```
The method doesn't preserve the original order of units when concatenating linked and unlinked sorts. The order will be: first all linked sorts (in the order they appear in LinkedSorts), then all remaining unlinked sorts. This could be surprising to users and may break downstream code that expects a specific ordering. Consider documenting this behavior or preserving the original order of units.
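One possible order-preserving variant, sketched as a standalone function under assumed inputs (the name `merge_preserving_order` and the `order`/`data` parameters are hypothetical, not from the PR): walk the units in their original order and emit each linked group at the position of its first member.

```python
import numpy as np

def merge_preserving_order(order, data, linked_id_list):
    """Sketch: emit units in their original order.

    order: original sequence of merge_ids.
    data: dict mapping merge_id -> (spike_times, waveform_features).
    linked_id_list: lists of merge_ids to concatenate together.
    A linked group is emitted at the position of its first member.
    """
    # Map every member id to its (hashable) group tuple.
    id_to_group = {m: tuple(g) for g in linked_id_list for m in g}
    emitted = set()
    times_out, features_out = [], []
    for merge_id in order:
        group = id_to_group.get(merge_id, (merge_id,))
        if group in emitted:
            continue  # group already emitted at its first member's slot
        emitted.add(group)
        t = np.concatenate([data[m][0] for m in group])
        f = np.concatenate([data[m][1] for m in group])
        ind_sort = np.argsort(t)  # keep spikes time-ordered after concat
        times_out.append(t[ind_sort])
        features_out.append(f[ind_sort])
    return times_out, features_out
```

This keeps unlinked units interleaved where they originally were instead of pushing them all to the end.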
```python
    keys: list[dict],
    linked_ids: list[list[str]] = None,
):
    """Create a group of waveform features for a given session"""
```
The create_group method now accepts a linked_ids parameter but doesn't document it. The docstring should be updated to document this new parameter, explaining that it's an optional list of lists where each inner list contains SpikeSortingMerge IDs that should be linked (concatenated) together.
Suggested change:
```diff
-    """Create a group of waveform features for a given session"""
+    """Create a group of waveform features for a given session.
+
+    Parameters
+    ----------
+    nwb_file_name : str
+        Name of the NWB file corresponding to the session.
+    group_name : str
+        Name to assign to this waveform features group.
+    keys : list of dict
+        List of primary keys for `UnitWaveformFeatures` entries to include
+        in this group.
+    linked_ids : list of list of str, optional
+        Optional list of lists, where each inner list contains
+        SpikeSortingMerge IDs that should be linked (concatenated)
+        together into a single group. If None, no SpikeSortingMerge
+        entries are linked.
+    """
```
```python
nwb_file_name: str,
group_name: str,
keys: list[dict],
linked_ids: list[list[str]] = None,
```
The parameter name linked_ids in create_group is somewhat ambiguous - it could be confused with the link_id field in the LinkedSorts table. Consider renaming it to linked_merge_ids to be more explicit and consistent with the naming in insert1_linked_ids and the field name in the LinkedSorts table.
```python
if linked_ids is not None:
    self.insert_linked_ids(group_key, linked_ids)

def fetch_data(self, key: dict = dict()):
```
The fetch_data method is defined as an instance method (taking self as the first parameter) but is being called as a class method in line 696. This will cause an error since the method needs to be called on an instance, not the class itself. Either change this to a class method by adding the @classmethod decorator and changing self to cls, or call it on an instance.
Suggested change:
```diff
-def fetch_data(self, key: dict = dict()):
+@classmethod
+def fetch_data(cls, key: dict = dict()):
```
Description

Resolves #1449

- `UnitWaveformFeaturesGroup.LinkedSorts`
- `UnitWaveformFeaturesGroup.fetch_data`
- `merge_id`

Checklist:
- `CITATION.cff`: if alter, include snippet for release notes.
- `CHANGELOG.md`: updated with PR number and description.