55import re
66import os
77from pathlib import Path
8+ from collections import defaultdict
89
910import numpy as np
1011import pandas as pd
1112from scipy .interpolate import interp1d
1213import matplotlib .pyplot as plt
1314
1415from one .api import ONE , One
15- from one .alf .path import get_alf_path , full_path_parts
16+ from one .alf .path import get_alf_path , full_path_parts , filename_parts
1617from one .alf .exceptions import ALFObjectNotFound , ALFMultipleCollectionsFound
1718from one .alf import cache
1819import one .alf .io as alfio
@@ -193,9 +194,9 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
193194 for pname in pnames :
194195 probe_collection = _get_spike_sorting_collection (collections , pname )
195196 spikes [pname ] = one .load_object (eid , collection = probe_collection , obj = 'spikes' ,
196- attribute = spike_attributes )
197+ attribute = spike_attributes , namespace = '' )
197198 clusters [pname ] = one .load_object (eid , collection = probe_collection , obj = 'clusters' ,
198- attribute = cluster_attributes )
199+ attribute = cluster_attributes , namespace = '' )
199200 if return_channels :
200201 channels = _load_channels_locations_from_disk (
201202 eid , collection = collection , one = one , revision = revision , brain_regions = brain_regions )
@@ -1035,7 +1036,31 @@ def load_channels(self, **kwargs):
10351036 self .histology = 'alf'
10361037 return Bunch (channels )
10371038
1038- def load_spike_sorting (self , spike_sorter = 'iblsorter' , revision = None , enforce_version = False , good_units = False , ** kwargs ):
@staticmethod
def filter_files_by_namespace(all_files, namespace):
    """
    Select one file per dataset, preferring the variant saved under a given namespace.

    Files are grouped by their 'object.attribute' name. For each group, the file whose namespace
    equals `namespace` is returned; if the group has no such file, the file without a namespace
    (stored under the key None) is returned instead. Groups that have neither are dropped.

    :param all_files: iterable of objects exposing a `name` attribute that holds an ALF filename
    :param namespace: namespace to select, e.g. 'av' to pick '_av_.clusters.depths'; None selects
     the files that have no namespace
    :return: list of files, at most one per 'object.attribute' name
    """
    # Map each 'object.attribute' name to its {namespace: file} variants; files without a
    # namespace are stored under the key None
    namespace_files = defaultdict(dict)
    for file in all_files:
        fparts = filename_parts(file.name, as_dict=True)
        fname = f"{fparts['object']}.{fparts['attribute']}"
        # Normalise an empty namespace to None so that the fallback lookup below can find it
        namespace_files[fname][fparts['namespace'] or None] = file

    available_namespaces = {nspace for variants in namespace_files.values() for nspace in variants}
    # Only report a missing namespace when one was actually requested: with namespace=None the
    # default files are what the caller asked for, so there is nothing to warn about
    if namespace is not None and namespace not in available_namespaces:
        _logger.info(f'Could not find manual curation results for {namespace}, returning default'
                     f' non manually curated spikesorting data')

    # Pick the requested namespace, falling back to the non-namespaced file of the same dataset
    files = (variants.get(namespace, variants.get(None)) for variants in namespace_files.values())
    # Drop datasets that exist only under some other namespace
    return [f for f in files if f is not None]
1061+
1062+ def load_spike_sorting (self , spike_sorter = 'iblsorter' , revision = None , enforce_version = False , good_units = False ,
1063+ namespace = None , ** kwargs ):
10391064 """
10401065 Loads spikes, clusters and channels
10411066
@@ -1053,6 +1078,8 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve
10531078 :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
10541079 :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
10551080 :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
1081+ :param namespace: None, if given will load the manually curated spikesorting with the given namespace,
1082+ e.g. to load '_av_.clusters.depths' use namespace='av'
10561083 :param kwargs: additional arguments to be passed to one.api.One.load_object
10571084 :return:
10581085 """
@@ -1061,13 +1088,21 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve
10611088 self .files = {}
10621089 self .spike_sorter = spike_sorter
10631090 self .revision = revision
1091+
1092+ if good_units and namespace is not None :
1093+ _logger .info ('Good units table does not exist for manually curated spike sorting. Pass in namespace with'
1094+ 'good_units=False and filter the spikes post hoc by the good clusters.' )
1095+ return [None ] * 3
10641096 objects = ['passingSpikes' , 'clusters' , 'channels' ] if good_units else None
10651097 self .download_spike_sorting (spike_sorter = spike_sorter , revision = revision , objects = objects , ** kwargs )
10661098 channels = self .load_channels (spike_sorter = spike_sorter , revision = revision , ** kwargs )
1099+ self .files ['clusters' ] = self .filter_files_by_namespace (self .files ['clusters' ], namespace )
10671100 clusters = self ._load_object (self .files ['clusters' ], wildcards = self .one .wildcards )
1101+
10681102 if good_units :
10691103 spikes = self ._load_object (self .files ['passingSpikes' ], wildcards = self .one .wildcards )
10701104 else :
1105+ self .files ['spikes' ] = self .filter_files_by_namespace (self .files ['spikes' ], namespace )
10711106 spikes = self ._load_object (self .files ['spikes' ], wildcards = self .one .wildcards )
10721107 if enforce_version :
10731108 self ._assert_version_consistency ()
0 commit comments