Skip to content

Commit 20c4de3

Browse files
authored
Merge pull request #954 from int-brain-lab/manual_curation_ssloader
Manual curation ssloader
2 parents 6e63c9e + 61d5bad commit 20c4de3

File tree

4 files changed

+54
-9
lines changed

4 files changed

+54
-9
lines changed

brainbox/io/one.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import re
66
import os
77
from pathlib import Path
8+
from collections import defaultdict
89

910
import numpy as np
1011
import pandas as pd
1112
from scipy.interpolate import interp1d
1213
import matplotlib.pyplot as plt
1314

1415
from one.api import ONE, One
15-
from one.alf.path import get_alf_path, full_path_parts
16+
from one.alf.path import get_alf_path, full_path_parts, filename_parts
1617
from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
1718
from one.alf import cache
1819
import one.alf.io as alfio
@@ -193,9 +194,9 @@ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_ch
193194
for pname in pnames:
194195
probe_collection = _get_spike_sorting_collection(collections, pname)
195196
spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes',
196-
attribute=spike_attributes)
197+
attribute=spike_attributes, namespace='')
197198
clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters',
198-
attribute=cluster_attributes)
199+
attribute=cluster_attributes, namespace='')
199200
if return_channels:
200201
channels = _load_channels_locations_from_disk(
201202
eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions)
@@ -1035,7 +1036,31 @@ def load_channels(self, **kwargs):
10351036
self.histology = 'alf'
10361037
return Bunch(channels)
10371038

1038-
def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs):
1039+
@staticmethod
1040+
def filter_files_by_namespace(all_files, namespace):
1041+
1042+
# Create dict for each file with available namespaces; a file with no namespace is stored under the key None
1043+
namespace_files = defaultdict(dict)
1044+
available_namespaces = []
1045+
for file in all_files:
1046+
fparts = filename_parts(file.name, as_dict=True)
1047+
fname = f"{fparts['object']}.{fparts['attribute']}"
1048+
nspace = fparts['namespace']
1049+
available_namespaces.append(nspace)
1050+
namespace_files[fname][nspace] = file
1051+
1052+
if namespace not in set(available_namespaces):
1053+
_logger.info(f'Could not find manual curation results for {namespace}, returning default'
1054+
f' non manually curated spikesorting data')
1055+
1056+
# Return the files with the chosen namespace.
1057+
files = [f.get(namespace, f.get(None, None)) for f in namespace_files.values()]
1058+
# remove any None files
1059+
files = [f for f in files if f]
1060+
return files
1061+
1062+
def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False,
1063+
namespace=None, **kwargs):
10391064
"""
10401065
Loads spikes, clusters and channels
10411066
@@ -1053,6 +1078,8 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve
10531078
:param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
10541079
:param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
10551080
:param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
1081+
:param namespace: None, if given will load the manually curated spikesorting with the given namespace,
1082+
e.g. to load '_av_.clusters.depths' use namespace='av'
10561083
:param kwargs: additional arguments to be passed to one.api.One.load_object
10571084
:return:
10581085
"""
@@ -1061,13 +1088,21 @@ def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_ve
10611088
self.files = {}
10621089
self.spike_sorter = spike_sorter
10631090
self.revision = revision
1091+
1092+
if good_units and namespace is not None:
1093+
_logger.info('Good units table does not exist for manually curated spike sorting. Pass in namespace with'
1094+
' good_units=False and filter the spikes post hoc by the good clusters.')
1095+
return [None] * 3
10641096
objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
10651097
self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
10661098
channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
1099+
self.files['clusters'] = self.filter_files_by_namespace(self.files['clusters'], namespace)
10671100
clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
1101+
10681102
if good_units:
10691103
spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
10701104
else:
1105+
self.files['spikes'] = self.filter_files_by_namespace(self.files['spikes'], namespace)
10711106
spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
10721107
if enforce_version:
10731108
self._assert_version_consistency()

ibllib/io/extractors/mesoscope.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,8 @@ def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data',
673673
assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
674674
ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
675675
assert ts.size >= imaging_data[
676-
'times_scanImage'].size, f"fewer DAQ timestamps for {collection} than expected: DAQ/frames = {ts.size}/{imaging_data['times_scanImage'].size}"
676+
'times_scanImage'].size, (f"fewer DAQ timestamps for {collection} than expected: "
677+
f"DAQ/frames = {ts.size}/{imaging_data['times_scanImage'].size}")
677678
if ts.size > imaging_data['times_scanImage'].size:
678679
_logger.warning(
679680
'More DAQ frame times detected for %s than were found in the raw image data.\n'

ibllib/pipes/video_tasks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ def run_qc(self, camera_data=None, update=True):
252252
if camera_data is None:
253253
camera_data, _ = self.extract_camera(save=False)
254254
qc = CameraQC(
255-
self.session_path, 'left', sync_type='bpod', sync_collection=self.collection, one=self.one)
255+
self.session_path, 'left', sync_type='bpod', sync_collection=self.collection, one=self.one,
256+
protocol=self.protocol)
256257
qc.run(update=update)
257258
return qc
258259

ibllib/qc/camera.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(self, session_path_or_eid, camera, **kwargs):
135135
self.n_samples = kwargs.pop('n_samples', 100)
136136
self.sync_collection = kwargs.pop('sync_collection', None)
137137
self.sync = kwargs.pop('sync_type', None)
138+
self.protocol = kwargs.pop('protocol', None)
138139
super().__init__(session_path_or_eid, **kwargs)
139140

140141
# Data
@@ -163,7 +164,10 @@ def __init__(self, session_path_or_eid, camera, **kwargs):
163164
self.outcome = spec.QC.NOT_SET
164165

165166
# Specify any checks to remove
166-
self.checks_to_remove = []
167+
if self.protocol is not None and 'habituation' in self.protocol:
168+
self.checks_to_remove = ['check_wheel_alignment']
169+
else:
170+
self.checks_to_remove = []
167171
self._type = None
168172

169173
@property
@@ -271,8 +275,12 @@ def load_data(self, extract_times: bool = False, load_video: bool = True) -> Non
271275
else:
272276
raise NotImplementedError(f'Unknown namespace "{ns}"')
273277
else:
274-
wheel_data = training_wheel.get_wheel_position(
275-
self.session_path, task_collection=task_collection)
278+
if self.protocol is not None and 'habituation' in self.protocol:
279+
wheel_data = training_wheel.get_wheel_position(
280+
self.session_path, task_collection=task_collection)
281+
else:
282+
wheel_data = [None, None]
283+
276284
self.data['wheel'] = Bunch(zip(wheel_keys, wheel_data))
277285

278286
# Find short period of wheel motion for motion correlation.

0 commit comments

Comments
 (0)