Skip to content

Commit 4561493

Browse files
committed
option to load manually curated spikesorting datasets
1 parent 9dc5ff3 commit 4561493

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

brainbox/io/one.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import re
66
import os
77
from pathlib import Path
8+
from collections import defaultdict
89

910
import numpy as np
1011
import pandas as pd
1112
from scipy.interpolate import interp1d
1213
import matplotlib.pyplot as plt
1314

1415
from one.api import ONE, One
15-
from one.alf.path import get_alf_path, full_path_parts
16+
from one.alf.path import get_alf_path, full_path_parts, filename_parts
1617
from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
1718
from one.alf import cache
1819
import one.alf.io as alfio
@@ -193,9 +194,9 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
193194
for pname in pnames:
194195
probe_collection = _get_spike_sorting_collection(collections, pname)
195196
spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes',
196-
attribute=spike_attributes)
197+
attribute=spike_attributes, namespace='')
197198
clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters',
198-
attribute=cluster_attributes)
199+
attribute=cluster_attributes, namespace='')
199200
if return_channels:
200201
channels = _load_channels_locations_from_disk(
201202
eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
@@ -1035,7 +1036,31 @@ def load_channels(self, **kwargs):
10351036
self.histology = 'alf'
10361037
return Bunch(channels)
10371038

1038-
def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs):
1039+
@staticmethod
1040+
def filter_files_by_namespace(all_files, namespace):
1041+
1042+
# Create dict for each file with available namespaces, no namespce is stored under the key None
1043+
namespace_files = defaultdict(dict)
1044+
available_namespaces = []
1045+
for file in all_files:
1046+
fparts = filename_parts(file.name, as_dict=True)
1047+
fname = f"{fparts['object']}.{fparts['attribute']}"
1048+
nspace = fparts['namespace']
1049+
available_namespaces.append(nspace)
1050+
namespace_files[fname][nspace] = file
1051+
1052+
if namespace not in set(available_namespaces):
1053+
_logger.info(f'Could not find manual curation results for {namespace}, returning default'
1054+
f' non manually curated spikesorting data')
1055+
1056+
# Return the files with the chosen namespace.
1057+
files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()]
1058+
# remove any None files
1059+
files = [f for f in files if f]
1060+
return files
1061+
1062+
def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False,
1063+
namespace=None, **kwargs):
10391064
"""
10401065
Loads spikes, clusters and channels
10411066
@@ -1053,6 +1078,8 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve
10531078
:param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
10541079
:param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
10551080
:param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
1081+
:param namespace: None, if given will load the manually curated spikesorting with the given namespace,
1082+
e.g. to load '_av_.clusters.depths' use namespace='av'
10561083
:param kwargs: additional arguments to be passed to one.api.One.load_object
10571084
:return:
10581085
"""
@@ -1061,13 +1088,21 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve
10611088
self.files = {}
10621089
self.spike_sorter = spike_sorter
10631090
self.revision = revision
1091+
1092+
if good_units and namespace is not None:
1093+
_logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with '
1094+
'good_units=False and filter the spikes post hoc by the good clusters.')
1095+
return [None] * 3
10641096
objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
10651097
self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
10661098
channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
1099+
self.files['clusters'] = self.filter_files_by_namespace(self.files['clusters'], namespace)
10671100
clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
1101+
10681102
if good_units:
10691103
spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
10701104
else:
1105+
self.files['spikes'] = self.filter_files_by_namespace(self.files['spikes'], namespace)
10711106
spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
10721107
if enforce_version:
10731108
self._assert_version_consistency()

0 commit comments

Comments
 (0)