
Commit e657db6

Merge branch 'release/2.5.0'
2 parents cab64bd + 850c3c4

22 files changed: +901, -149 lines

ibllib/dsp/fourier.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def convolve(x, w, mode='full'):
     ns = ns_optim_fft(nsx + nsw)
     x_ = np.concatenate((x, np.zeros([*x.shape[:-1], ns - nsx], dtype=x.dtype)), axis=-1)
     w_ = np.concatenate((w, np.zeros([*w.shape[:-1], ns - nsw], dtype=w.dtype)), axis=-1)
-    xw = np.fft.irfft(np.fft.rfft(x_, axis=-1) * np.fft.rfft(w_, axis=-1), axis=-1)
+    xw = np.real(np.fft.irfft(np.fft.rfft(x_, axis=-1) * np.fft.rfft(w_, axis=-1), axis=-1))
     xw = xw[..., :(nsx + nsw)]  # remove 0 padding
     if mode == 'full':
         return xw
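For context, the change above keeps the FFT-based convolution but wraps the inverse transform in `np.real`. Below is a minimal standalone sketch of the same zero-padded rfft/irfft convolution (illustrative only; the optimal-length padding via `ns_optim_fft` is omitted here):

```python
import numpy as np

def fft_convolve_full(x, w):
    """Full linear convolution along the last axis via the FFT (sketch)."""
    nsx, nsw = x.shape[-1], w.shape[-1]
    ns = nsx + nsw  # zero-pad so the circular convolution equals the linear one
    x_ = np.concatenate((x, np.zeros([*x.shape[:-1], ns - nsx], dtype=x.dtype)), axis=-1)
    w_ = np.concatenate((w, np.zeros([*w.shape[:-1], ns - nsw], dtype=w.dtype)), axis=-1)
    # multiply the spectra and transform back; np.real mirrors the change above
    xw = np.real(np.fft.irfft(np.fft.rfft(x_, axis=-1) * np.fft.rfft(w_, axis=-1), n=ns, axis=-1))
    return xw[..., :nsx + nsw - 1]  # trim the zero padding

# agrees with the direct implementation for 1d signals
x, w = np.random.randn(500), np.hanning(25)
assert np.allclose(fft_convolve_full(x, w), np.convolve(x, w, mode='full'))
```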

ibllib/dsp/voltage.py

Lines changed: 329 additions & 62 deletions
Large diffs are not rendered by default.

ibllib/ephys/ephysqc.py

Lines changed: 37 additions & 22 deletions
@@ -7,7 +7,7 @@

 import numpy as np
 import pandas as pd
-from scipy import signal
+from scipy import signal, stats
 from tqdm import tqdm
 import one.alf.io as alfio
 from iblutil.util import Bunch
@@ -105,7 +105,10 @@ def _compute_metrics_array(raw, fs, h):
         rms_pre_proc = dsp.rms(destripe)
         detections = spikes.detection(data=destripe.T, fs=fs, h=h, detect_threshold=SPIKE_THRESHOLD_UV * 1e-6)
         spike_rate = np.bincount(detections.trace, minlength=raw.shape[0]).astype(np.float32)
-        return rms_raw, rms_pre_proc, spike_rate
+        channel_labels, _ = dsp.voltage.detect_bad_channels(raw, fs=fs)
+        _, psd = signal.welch(destripe, fs=fs, window='hanning', nperseg=WELCH_WIN_LENGTH_SAMPLES,
+                              detrend='constant', return_onesided=True, scaling='density', axis=-1)
+        return rms_raw, rms_pre_proc, spike_rate, channel_labels, psd

     def run(self, update: bool = False, overwrite: bool = True, stream: bool = None, **kwargs) -> (str, dict):
         """
@@ -124,14 +127,18 @@ def run(self, update: bool = False, overwrite: bool = True, stream: bool = None,
         self.load_data()
         qc_files = []
         # If ap meta file present, calculate median RMS per channel before and after destriping
-        # TODO: This should go a a separate function once we have a spikeglx.Streamer that behaves like the Reader
+        # NB: ideally this should go a a separate function once we have a spikeglx.Streamer that behaves like the Reader
         if self.data.ap_meta:
-            rms_file = self.probe_path.joinpath("_iblqc_ephysChannels.apRMS.npy")
-            spike_rate_file = self.probe_path.joinpath("_iblqc_ephysChannels.rawSpikeRates.npy")
-            if all([rms_file.exists(), spike_rate_file.exists()]) and not overwrite:
+            files = {'rms': self.probe_path.joinpath("_iblqc_ephysChannels.apRMS.npy"),
+                     'spike_rate': self.probe_path.joinpath("_iblqc_ephysChannels.rawSpikeRates.npy"),
+                     'channel_labels': self.probe_path.joinpath("_iblqc_ephysChannels.labels.npy"),
+                     'ap_freqs': self.probe_path.joinpath("_iblqc_ephysSpectralDensityAP.freqs.npy"),
+                     'ap_power': self.probe_path.joinpath("_iblqc_ephysSpectralDensityAP.power.npy"),
+                     }
+            if all([files[k].exists() for k in files]) and not overwrite:
                 _logger.warning(f'RMS map already exists for .ap data in {self.probe_path}, skipping. '
                                 f'Use overwrite option.')
-                median_rms = np.load(rms_file)
+                results = {k: np.load(files[k]) for k in files}
             else:
                 rl = self.data.ap_meta.fileTimeSecs
                 nsync = len(spikeglx._get_sync_trace_indices_from_meta(self.data.ap_meta))
@@ -145,14 +152,18 @@
                     raise ValueError("Wrong Neuropixel channel mapping used - ABORT")
                 t0s = np.arange(TMIN, rl - SAMPLE_LENGTH, BATCHES_SPACING)
                 all_rms = np.zeros((2, nc, t0s.shape[0]))
-                all_srs = np.zeros((nc, t0s.shape[0]))
+                all_srs, channel_ok = (np.zeros((nc, t0s.shape[0])) for _ in range(2))
+                psds = np.zeros((nc, dsp.fscale(WELCH_WIN_LENGTH_SAMPLES, 1, one_sided=True).size))
                 # If the ap.bin file is not present locally, stream it
                 if self.data.ap is None and self.stream is True:
                     _logger.warning(f'Streaming .ap data to compute RMS samples for probe {self.pid}')
                     for i, t0 in enumerate(tqdm(t0s)):
                         sr, _ = sglx_streamer(self.pid, t0=t0, nsecs=1, one=self.one, remove_cached=True)
                         raw = sr[:, :-nsync].T
-                        all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i] = self._compute_metrics_array(raw, sr.fs, h)
+                        all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i], channel_ok[:, i], psd =\
+                            self._compute_metrics_array(raw, sr.fs, h)
+                        psds += psd
+                        fs = sr.fs
                 elif self.data.ap is None and self.stream is not True:
                     _logger.warning('Raw .ap data is not available locally. Run with stream=True in order to stream '
                                     'data for calculating RMS samples.')
@@ -161,23 +172,27 @@
                     for i, t0 in enumerate(t0s):
                         sl = slice(int(t0 * self.data.ap.fs), int((t0 + SAMPLE_LENGTH) * self.data.ap.fs))
                         raw = self.data.ap[sl, :-nsync].T
-                        all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i] = self._compute_metrics_array(raw, self.data.ap.fs, h)
+                        all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i], channel_ok[:, i], psd =\
+                            self._compute_metrics_array(raw, self.data.ap.fs, h)
+                        fs = self.data.ap.fs
+                        psds += psd
                 # Calculate the median RMS across all samples per channel
-                median_rms = np.median(all_rms, axis=-1)
-                median_spike_rate = np.median(all_srs, axis=-1)
-                np.save(rms_file, median_rms)
-                np.save(spike_rate_file, median_spike_rate)
-                qc_files.extend([rms_file, spike_rate_file])
-
+                results = {'rms': np.median(all_rms, axis=-1),
+                           'spike_rate': np.median(all_srs, axis=-1),
+                           'channel_labels': stats.mode(channel_ok, axis=1)[0],
+                           'ap_freqs': dsp.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / fs, one_sided=True),
+                           'ap_power': psds.T / len(t0s),  # shape: (nfreqs, nchannels)
+                           }
+                for k in files:
+                    np.save(files[k], results[k])
+                qc_files.extend([files[k] for k in files])
             for p in [10, 90]:
-                self.metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(np.percentile(median_rms[0, :], p),
-                                                                             precision=2)
-                self.metrics[f'apRms_p{p}_proc'] = np.format_float_scientific(np.percentile(median_rms[1, :], p),
-                                                                              precision=2)
+                self.metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(
+                    np.percentile(results['rms'][0, :], p), precision=2)
+                self.metrics[f'apRms_p{p}_proc'] = np.format_float_scientific(
+                    np.percentile(results['rms'][1, :], p), precision=2)
             if update:
                 self.update_extended_qc(self.metrics)
-                # self.update(outcome)
-
         # If lf meta and bin file present, run the old qc on LF data
         if self.data.lf_meta and self.data.lf:
            qc_files.extend(extract_rmsmap(self.data.lf, out_folder=self.probe_path, overwrite=overwrite))
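Across the 1-second snippets, the QC now keeps the median RMS per channel, the most frequent channel label, and the snippet-averaged PSD. A hypothetical sketch of that aggregation with fake per-snippet arrays (not the QC code itself):

```python
import numpy as np
from scipy import stats

nc, nsnippets = 384, 10                                      # hypothetical channel / snippet counts
all_rms = np.random.rand(2, nc, nsnippets)                   # raw and destriped RMS per channel per snippet
channel_ok = np.random.randint(0, 4, size=(nc, nsnippets))   # per-snippet channel labels
psds = np.random.rand(nc, 513) * nsnippets                   # PSDs summed over snippets

median_rms = np.median(all_rms, axis=-1)                     # (2, nc): median over snippets
channel_labels = stats.mode(channel_ok, axis=1)[0]           # most common label per channel
mean_psd = psds.T / nsnippets                                # (nfreqs, nchannels), as saved to ap_power
```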

ibllib/ephys/spikes.py

Lines changed: 4 additions & 2 deletions
@@ -143,6 +143,8 @@ def _sr(ap_file):
     out_files.extend([f for f in out_path.glob("*.*") if
                       f.name.startswith(('channels.', 'drift', 'clusters.', 'spikes.', 'templates.',
                                          '_kilosort_', '_phy_spikes_subset', '_ibl_log.info'))])
+    # the QC files computed during spike sorting stay within the raw ephys data folder
+    out_files.extend(list(ap_file.parent.glob('_iblqc_*AP.*.npy')))
     return out_files, 0


@@ -159,7 +161,7 @@ def ks2_to_alf(ks_path, bin_path, out_path, bin_file=None, ampfactor=1, label=No
     ac.convert(out_path, label=label, force=force, ampfactor=ampfactor)


-def ks2_to_tar(ks_path, out_path):
+def ks2_to_tar(ks_path, out_path, force=False):
     """
     Compress output from kilosort 2 into tar file in order to register to flatiron and move to
     spikesorters/ks2_matlab/probexx path. Output file to register
@@ -199,7 +201,7 @@ def ks2_to_tar(ks_path, out_path):
                 'whitening_mat_inv.npy']

     out_file = Path(out_path).joinpath('_kilosort_raw.output.tar')
-    if out_file.exists():
+    if out_file.exists() and not force:
         _logger.info(f"Already converted ks2 to tar: for {ks_path}, skipping.")
         return [out_file]

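A minimal sketch of the overwrite guard that the new `force` argument enables, using hypothetical file names rather than the actual kilosort output list:

```python
import tarfile
from pathlib import Path

def make_tar(out_path, members, force=False):
    """Bundle files into a tar archive; skip if the archive already exists unless force=True."""
    out_file = Path(out_path).joinpath('output.tar')
    if out_file.exists() and not force:
        print(f'{out_file} already exists, skipping')
        return [out_file]
    with tarfile.open(out_file, 'w') as tar:
        for f in members:
            tar.add(f, arcname=Path(f).name)
    return [out_file]
```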

ibllib/ephys/sync_probes.py

Lines changed: 2 additions & 1 deletion
@@ -140,7 +140,8 @@ def version3B(ses_path, display=True, type=None, tol=2.5):
         sync_probe = get_sync_fronts(ef.sync, ef.sync_map['imec_sync'])
         sr = _get_sr(ef)
         try:
-            assert(sync_nidq.times.size == sync_probe.times.size)
+            # we say that the number of pulses should be within 10 %
+            assert(np.isclose(sync_nidq.times.size, sync_probe.times.size, rtol=0.1))
         except AssertionError:
             raise Neuropixel3BSyncFrontsNonMatching(f"{ses_path}")
         # if the qc of the diff finds anomalies, do not attempt to smooth the interp function
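For reference, the relaxed assertion passes when the two pulse counts agree to within roughly 10 % (relative to the second argument):

```python
import numpy as np

# hypothetical pulse counts on the nidq and probe sync lines
assert np.isclose(500, 470, rtol=0.1)        # ~6 % apart: accepted
assert not np.isclose(500, 420, rtol=0.1)    # ~19 % apart: the except branch raises Neuropixel3BSyncFrontsNonMatching
```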

ibllib/oneibl/aws.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,7 @@ def __init__(self, s3_bucket_name=None, one=None):

     def _download_datasets(self, datasets):

+        files = []
         for _, d in datasets.iterrows():
             rel_file_path = Path(d['session_path']).joinpath(d['rel_path'])
             file_path = Path(self.one.cache_dir).joinpath(rel_file_path)
@@ -54,5 +55,8 @@ def _download_datasets(self, datasets):
                 _logger.info(f'Downloading {aws_path} to {file_path}')
                 self.bucket.download_file(aws_path, file_path.as_posix())
                 _logger.debug(f'Complete. Time elapsed {time() - ts} for {file_path}')
+                files.append(file_path)
             else:
                 _logger.warning(f'{aws_path} not found on s3 bucket: {self.bucket.name}')
+
+        return files
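`_download_datasets` now returns the local paths it fetched so that callers can clean them up later. A stripped-down sketch of that pattern with boto3 (bucket and key names are made up):

```python
from pathlib import Path
import boto3

def download_all(bucket_name, keys, cache_dir):
    """Download S3 objects and return the list of local file paths (illustrative only)."""
    bucket = boto3.resource('s3').Bucket(bucket_name)
    files = []
    for key in keys:
        file_path = Path(cache_dir).joinpath(key)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        bucket.download_file(key, file_path.as_posix())
        files.append(file_path)
    return files
```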

ibllib/oneibl/data_handlers.py

Lines changed: 87 additions & 12 deletions
@@ -7,8 +7,9 @@
 import abc
 from time import time

+from one.api import ONE
 from one.util import filter_datasets
-from one.alf.files import add_uuid_string
+from one.alf.files import add_uuid_string, session_path_parts
 from iblutil.io.parquet import np2str
 from ibllib.oneibl.registration import register_dataset
 from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH
@@ -36,15 +37,17 @@ def setUp(self):
         """
         pass

-    def getData(self):
+    def getData(self, one=None):
         """
         Finds the datasets required for task based on input signatures
         :return:
         """
-        if self.one is None:
+        if self.one is None and one is None:
             return
-        session_datasets = self.one.list_datasets(self.one.path2eid(self.session_path), details=True)
-        df = pd.DataFrame(columns=self.one._cache.datasets.columns)
+
+        one = one or self.one
+        session_datasets = one.list_datasets(one.path2eid(self.session_path), details=True)
+        df = pd.DataFrame(columns=one._cache.datasets.columns)
         for file in self.signature['input_files']:
             df = df.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
                                            wildcards=True, assert_unique=False))
@@ -131,14 +134,21 @@ def __init__(self, session_path, signatures, one=None):
         else:
             self.lab = labs[0]

-        self.globus.add_endpoint(f'flatiron_{self.lab}')
+        # For cortex lab we need to get the endpoint from the ibl alyx
+        if self.lab == 'cortexlab':
+            self.globus.add_endpoint(f'flatiron_{self.lab}', one=ONE(base_url='https://alyx.internationalbrainlab.org'))
+        else:
+            self.globus.add_endpoint(f'flatiron_{self.lab}')

     def setUp(self):
         """
         Function to download necessary data to run tasks using globus-sdk
         :return:
         """
-        df = super().getData()
+        if self.lab == 'cortexlab':
+            df = super().getData(one=ONE(base_url='https://alyx.internationalbrainlab.org'))
+        else:
+            df = super().getData()

         if len(df) == 0:
             # If no datasets found in the cache only work off local file system do not attempt to download any missing data
@@ -225,24 +235,29 @@ def uploadData(self, outputs, version, **kwargs):


 class RemoteAwsDataHandler(DataHandler):
-    def __init__(self, session_path, signature, one=None):
+    def __init__(self, task, session_path, signature, one=None):
         """
         Data handler for running tasks on remote compute node. Will download missing data from private ibl s3 AWS data bucket

         :param session_path: path to session
         :param signature: input and output file signatures
         :param one: ONE instance
         """
+        from one.globus import Globus  # noqa
         super().__init__(session_path, signature, one=one)
+        self.task = task
         self.aws = AWS(one=self.one)
+        self.globus = Globus(client_name='server')
+        self.lab = session_path_parts(self.session_path, as_dict=True)['lab']
+        self.globus.add_endpoint(f'flatiron_{self.lab}')

     def setUp(self):
         """
         Function to download necessary data to run tasks using AWS boto3
         :return:
         """
         df = super().getData()
-        self.aws._download_datasets(df)
+        self.local_paths = self.aws._download_datasets(df)

     def uploadData(self, outputs, version, **kwargs):
         """
@@ -251,10 +266,70 @@ def uploadData(self, outputs, version, **kwargs):
         :param version: ibllib version
         :return: output info of registered datasets
         """
+
+        # register datasets
         versions = super().uploadData(outputs, version)
-        ftp_patcher = FTPPatcher(one=self.one)
-        return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
-                                          versions=versions, **kwargs)
+        response = register_dataset(outputs, one=self.one, server_only=True, versions=versions, **kwargs)
+
+        # upload directly via globus
+        source_paths = []
+        target_paths = []
+        collections = {}
+
+        for dset, out in zip(response, outputs):
+            assert (Path(out).name == dset['name'])
+            # set flag to false
+            fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository'])
+            collection = '/'.join(fr['relative_path'].split('/')[:-1])
+            if collection in collections.keys():
+                collections[collection].update({f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}})
+            else:
+                collections[collection] = {f'{dset["name"]}': {'fr_id': fr['id'], 'size': dset['file_size']}}
+
+            # Set all exists status to false for server file records
+            self.one.alyx.rest('files', 'partial_update', id=fr['id'], data={'exists': False})
+
+            source_paths.append(out)
+            target_paths.append(add_uuid_string(fr['relative_path'], dset['id']))
+
+        if len(target_paths) != 0:
+            ts = time()
+            for sp, tp in zip(source_paths, target_paths):
+                _logger.info(f'Uploading {sp} to {tp}')
+            self.globus.mv('local', f'flatiron_{self.lab}', source_paths, target_paths)
+            _logger.debug(f'Complete. Time elapsed {time() - ts}')
+
+        for collection, files in collections.items():
+            globus_files = self.globus.ls(f'flatiron_{self.lab}', collection, remove_uuid=True, return_size=True)
+            file_names = [gl[0] for gl in globus_files]
+            file_sizes = [gl[1] for gl in globus_files]
+
+            for name, details in files.items():
+                try:
+                    idx = file_names.index(name)
+                    size = file_sizes[idx]
+                    if size == details['size']:
+                        # update the file record if sizes match
+                        self.one.alyx.rest('files', 'partial_update', id=details['fr_id'], data={'exists': True})
+                    else:
+                        _logger.warning(f'File {name} found on SDSC but sizes do not match')
+                except ValueError:
+                    _logger.warning(f'File {name} not found on SDSC')
+
+        return response
+
+        # ftp_patcher = FTPPatcher(one=self.one)
+        # return ftp_patcher.create_dataset(path=outputs, created_by=self.one.alyx.user,
+        #                                   versions=versions, **kwargs)
+
+    def cleanUp(self):
+        """
+        Clean up, remove the files that were downloaded from globus once task has completed
+        :return:
+        """
+        if self.task.status == 0:
+            for file in self.local_paths:
+                os.unlink(file)


 class RemoteGlobusDataHandler(DataHandler):
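The verification step at the end of `uploadData` compares remote file sizes against the registered ones and only then flips the `exists` flag. A simplified sketch of that loop, where `remote_listing` and `mark_exists` are hypothetical stand-ins for the Globus `ls` result and the Alyx REST call:

```python
import logging

_logger = logging.getLogger(__name__)

def confirm_uploads(expected, remote_listing, mark_exists):
    """expected: {name: {'fr_id': ..., 'size': ...}}; remote_listing: [(name, size), ...]."""
    file_names = [name for name, _ in remote_listing]
    file_sizes = [size for _, size in remote_listing]
    for name, details in expected.items():
        try:
            idx = file_names.index(name)
        except ValueError:
            _logger.warning(f'File {name} not found on the remote endpoint')
            continue
        if file_sizes[idx] == details['size']:
            mark_exists(details['fr_id'])  # e.g. one.alyx.rest('files', 'partial_update', ...)
        else:
            _logger.warning(f'File {name} found but sizes do not match')
```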
