Skip to content

Commit b1dd4e1

Browse files
committed
Merge branch 'release/2.40.0'
2 parents c4418cb + b235446 commit b1dd4e1

33 files changed

+454
-220
lines changed

brainbox/io/one.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -866,13 +866,21 @@ def _get_attributes(dataset_types):
866866
waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
867867
return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}
868868

869-
def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
869+
def _get_spike_sorting_collection(self, spike_sorter=None):
870870
"""
871871
Filters a list or array of collections to get the relevant spike sorting dataset
872872
if there is a pykilosort, load it
873873
"""
874-
collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None)
875-
# otherwise, prefers the shortest
874+
for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']):
875+
if sorter is None:
876+
continue
877+
if sorter == "":
878+
collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None)
879+
else:
880+
collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None)
881+
if collection is not None:
882+
return collection
883+
# if none is found amongst the defaults, prefers the shortest
876884
collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
877885
_logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
878886
return collection
@@ -982,14 +990,13 @@ def download_raw_waveforms(self, **kwargs):
982990
"""
983991
_logger.debug(f"loading waveforms from {self.collection}")
984992
return self.one.load_object(
985-
self.eid, "waveforms",
986-
attribute=["traces", "templates", "table", "channels"],
993+
id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"],
987994
collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
988995
)
989996

990997
def raw_waveforms(self, **kwargs):
991998
wf_paths = self.download_raw_waveforms(**kwargs)
992-
return WaveformsLoader(wf_paths[0].parent, wfs_dtype=np.float16)
999+
return WaveformsLoader(wf_paths[0].parent)
9931000

9941001
def load_channels(self, **kwargs):
9951002
"""
@@ -1022,7 +1029,7 @@ def load_channels(self, **kwargs):
10221029
self.histology = 'alf'
10231030
return Bunch(channels)
10241031

1025-
def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
1032+
def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs):
10261033
"""
10271034
Loads spikes, clusters and channels
10281035

brainbox/io/spikeglx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
128128
self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}")
129129
meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}")
130130
cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True)
131+
cbin_rec.index = cbin_rec.index.map(lambda x: (self.eid, x))
131132
self.url_cbin = self.one.record2url(cbin_rec)[0]
132133
with open(self.file_chunks, 'r') as f:
133134
self.chunks = json.load(f)

ibllib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import warnings
44

5-
__version__ = '2.39.1'
5+
__version__ = '2.40.0'
66
warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib')
77

88
# if this becomes a full-blown library we should let the logging configuration to the discretion of the dev

ibllib/ephys/sync_probes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def sync(ses_path, **kwargs):
4747
return version3B(ses_path, **kwargs)
4848

4949

50-
def version3A(ses_path, display=True, type='smooth', tol=2.1):
50+
def version3A(ses_path, display=True, type='smooth', tol=2.1, probe_names=None):
5151
"""
5252
From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3A and
5353
outputs one sync.timestamps.probeN.npy file per acquired probe. By convention the reference

ibllib/io/extractors/biased_trials.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class TrialsTableBiased(BaseBpodTrialsExtractor):
9797
save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
9898
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
9999
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
100-
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')
100+
'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement')
101101

102102
def _extract(self, extractor_classes=None, **kwargs):
103103
extractor_classes = extractor_classes or []
@@ -125,7 +125,7 @@ class TrialsTableEphys(BaseBpodTrialsExtractor):
125125
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
126126
None, None, None, '_ibl_trials.quiescencePeriod.npy')
127127
var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
128-
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
128+
'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
129129
'phase', 'position', 'quiescence')
130130

131131
def _extract(self, extractor_classes=None, **kwargs):
@@ -152,12 +152,12 @@ class BiasedTrials(BaseBpodTrialsExtractor):
152152
save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
153153
'_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt',
154154
'_ibl_trials.stimOff_times.npy', None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
155-
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, '_ibl_trials.included.npy',
156-
None, None, '_ibl_trials.quiescencePeriod.npy')
155+
'_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
156+
'_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
157157
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
158158
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
159-
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
160-
'phase', 'position', 'quiescence')
159+
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
160+
'included', 'phase', 'position', 'quiescence')
161161

162162
def _extract(self, extractor_classes=None, **kwargs) -> dict:
163163
extractor_classes = extractor_classes or []
@@ -182,8 +182,8 @@ class EphysTrials(BaseBpodTrialsExtractor):
182182
'_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
183183
var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
184184
'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
185-
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
186-
'phase', 'position', 'quiescence')
185+
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
186+
'included', 'phase', 'position', 'quiescence')
187187

188188
def _extract(self, extractor_classes=None, **kwargs) -> dict:
189189
extractor_classes = extractor_classes or []

ibllib/io/extractors/bpod_trials.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
This module will extract the Bpod trials and wheel data based on the task protocol,
44
i.e. habituation, training or biased.
55
"""
6-
import logging
76
import importlib
87

9-
from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor
8+
from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor, BaseExtractor
109
from ibllib.io.extractors.habituation_trials import HabituationTrials
1110
from ibllib.io.extractors.training_trials import TrainingTrials
1211
from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
1312
from ibllib.io.extractors.base import BaseBpodTrialsExtractor
1413

15-
_logger = logging.getLogger(__name__)
16-
1714

1815
def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
1916
"""
@@ -39,20 +36,25 @@ def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavio
3936
'BiasedTrials': BiasedTrials,
4037
'EphysTrials': EphysTrials
4138
}
39+
4240
if protocol:
43-
class_name = protocol2extractor(protocol)
41+
extractor_class_name = protocol2extractor(protocol)
4442
else:
45-
class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
46-
if class_name in builtins:
47-
return builtins[class_name](session_path)
43+
extractor_class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
44+
if extractor_class_name in builtins:
45+
return builtins[extractor_class_name](session_path)
4846

4947
# look if there are custom extractor types in the personal projects repo
50-
if not class_name.startswith('projects.'):
51-
class_name = 'projects.' + class_name
52-
module, class_name = class_name.rsplit('.', 1)
48+
if not extractor_class_name.startswith('projects.'):
49+
extractor_class_name = 'projects.' + extractor_class_name
50+
module, extractor_class_name = extractor_class_name.rsplit('.', 1)
5351
mdl = importlib.import_module(module)
54-
extractor_class = getattr(mdl, class_name, None)
52+
extractor_class = getattr(mdl, extractor_class_name, None)
5553
if extractor_class:
56-
return extractor_class(session_path)
54+
my_extractor = extractor_class(session_path)
55+
if not isinstance(my_extractor, BaseExtractor):
56+
raise ValueError(
57+
f"{my_extractor} should be an Extractor class inheriting from ibllib.io.extractors.base.BaseExtractor")
58+
return my_extractor
5759
else:
58-
raise ValueError(f'extractor {class_name} not found')
60+
raise ValueError(f'extractor {extractor_class_name} not found')

ibllib/io/extractors/ephys_fpga.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,13 +585,13 @@ class FpgaTrials(extractors_base.BaseExtractor):
585585
'_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy',
586586
'_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy',
587587
'_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy',
588-
'_ibl_wheelMoves.peakAmplitude.npy')
588+
'_ibl_wheelMoves.peakAmplitude.npy', None)
589589
var_names = ('goCueTrigger_times', 'stimOnTrigger_times',
590590
'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times',
591591
'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times',
592592
'valveOpen_times', 'phase', 'position', 'quiescence', 'table',
593593
'wheel_timestamps', 'wheel_position',
594-
'wheelMoves_intervals', 'wheelMoves_peakAmplitude')
594+
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times')
595595

596596
bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times',
597597
'stimOnTrigger_times', 'stimOffTrigger_times',

ibllib/io/extractors/fibrephotometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _extract(self, light_source_map=None, collection=None, regions=None, **kwarg
225225
regions = regions or [k for k in fp_data['raw'].keys() if 'Region' in k]
226226
out_df = fp_data['raw'].filter(items=regions, axis=1).sort_index(axis=1)
227227
out_df['times'] = ts
228-
out_df['wavelength'] = np.NaN
228+
out_df['wavelength'] = np.nan
229229
out_df['name'] = ''
230230
out_df['color'] = ''
231231
# Extract channel index

ibllib/io/extractors/mesoscope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import numpy as np
55
from scipy.signal import find_peaks
66
import one.alf.io as alfio
7-
from one.util import ensure_list
87
from one.alf.files import session_path_parts
8+
from iblutil.util import ensure_list
99
import matplotlib.pyplot as plt
1010
from packaging import version
1111

ibllib/io/extractors/opto_trials.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class LaserBool(BaseBpodTrialsExtractor):
1616
def _extract(self, **kwargs):
1717
_logger.info('Extracting laser datasets')
1818
# reference pybpod implementation
19-
lstim = np.array([float(t.get('laser_stimulation', np.NaN)) for t in self.bpod_trials])
20-
lprob = np.array([float(t.get('laser_probability', np.NaN)) for t in self.bpod_trials])
19+
lstim = np.array([float(t.get('laser_stimulation', np.nan)) for t in self.bpod_trials])
20+
lprob = np.array([float(t.get('laser_probability', np.nan)) for t in self.bpod_trials])
2121

2222
# Karolina's choice world legacy implementation - from Slack message:
2323
# it is possible that some versions I have used:
@@ -30,9 +30,9 @@ def _extract(self, **kwargs):
3030
# laserOFF_trials=(optoOUT ==0);
3131
if 'PROBABILITY_OPTO' in self.settings.keys() and np.all(np.isnan(lstim)):
3232
lprob = np.zeros_like(lprob) + self.settings['PROBABILITY_OPTO']
33-
lstim = np.array([float(t.get('opto_ON_time', np.NaN)) for t in self.bpod_trials])
33+
lstim = np.array([float(t.get('opto_ON_time', np.nan)) for t in self.bpod_trials])
3434
if np.all(np.isnan(lstim)):
35-
lstim = np.array([float(t.get('optoOUT', np.NaN)) for t in self.bpod_trials])
35+
lstim = np.array([float(t.get('optoOUT', np.nan)) for t in self.bpod_trials])
3636
lstim[lstim == 255] = 1
3737
else:
3838
lstim[~np.isnan(lstim)] = 1

0 commit comments

Comments
 (0)