
Commit 03fc88f

Merge pull request #403 from int-brain-lab/multiparts
lock in GPU tasks - local data handler
2 parents 95a0553 + 30fd082

9 files changed (+239, -27 lines)

brainbox/ephys_plots.py

Lines changed: 99 additions & 1 deletion
@@ -1,9 +1,10 @@
 import numpy as np
 from matplotlib import cm
-
+import matplotlib.pyplot as plt
 from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
                                 plot_image, plot_probe, plot_scatter, arrange_channels2banks)
 from brainbox.processing import bincount2D, compute_cluster_average
+from ibllib.atlas.regions import BrainRegions


 def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 300),
@@ -372,3 +373,100 @@ def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, d
         fig, ax = plot_line(data.convert2dict())
         return data.convert2dict(), fig, ax
     return data
+
+
+def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None):
+    """
+    Plot brain regions along probe; if channel_depths is provided, plots along depth, otherwise along channel index
+    :param channel_ids: atlas ids for each channel
+    :param channel_depths: depth along probe for each channel
+    :param brain_regions: BrainRegions object
+    :param display: whether to output plot
+    :param ax: axis to plot on
+    :return:
+    """
+
+    if channel_depths is not None:
+        assert channel_ids.shape[0] == channel_depths.shape[0]
+
+    br = brain_regions or BrainRegions()
+
+    region_info = br.get(channel_ids)
+    boundaries = np.where(np.diff(region_info.id) != 0)[0]
+    boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]
+
+    regions = np.c_[boundaries[0:-1], boundaries[1:]]
+    if channel_depths is not None:
+        regions = channel_depths[regions]
+    region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
+    region_colours = region_info.rgb[boundaries[1:]]
+
+    if display:
+        if ax is None:
+            fig, ax = plt.subplots()
+
+        for reg, col in zip(regions, region_colours):
+            height = np.abs(reg[1] - reg[0])
+            color = col / 255
+            ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')
+        ax.set_yticks(region_labels[:, 0].astype(int))
+        ax.yaxis.set_tick_params(labelsize=8)
+        ax.get_xaxis().set_visible(False)
+        ax.set_yticklabels(region_labels[:, 1])
+        ax.spines['right'].set_visible(False)
+        ax.spines['top'].set_visible(False)
+        ax.spines['bottom'].set_visible(False)
+
+        return fig, ax
+    else:
+        return regions, region_labels, region_colours
+
+
+def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
+             display=False, cmap='hot'):
+    """
+    Plot cumulative amplitude of spikes across depth
+    :param spike_amps: spike amplitudes
+    :param spike_depths: depth along probe of each spike (um)
+    :param spike_times: spike times (s)
+    :param n_amp_bins: number of amplitude bins to use
+    :param d_bin: the value of the depth bins in um (default is 40 um)
+    :param amp_range: amp range to use [amp_min, amp_max]; if not given, automatically computed from spike_amps
+    :param d_range: depth range to use, by default [0, 3840]
+    :param display: whether or not to display plot
+    :param cmap: colormap to use (default 'hot')
+    :return:
+    """
+
+    amp_range = amp_range or np.quantile(spike_amps, (0, 0.9))
+    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
+    d_range = d_range or [0, 3840]
+    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
+    t_bin = np.max(spike_times)
+
+    def histc(x, bins):
+        map_to_bins = np.digitize(x, bins)  # Get indices of the bins to which each value in input array belongs
+        res = np.zeros(bins.shape)
+
+        for el in map_to_bins:
+            res[el - 1] += 1  # Increment appropriate bin
+        return res
+
+    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
+    for d in range(len(depth_bins) - 1):
+        spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
+        h = histc(spike_amps[spikes], amp_bins) / t_bin
+        hcsum = np.cumsum(h[::-1])
+        cdfs[d, :] = hcsum[::-1]
+
+    cdfs[cdfs == 0] = np.nan
+
+    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
+    data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
+                    ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
+
+    if display:
+        fig, ax = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]})
+        return data.convert2dict(), fig, ax
+
+    return data
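A minimal usage sketch for the two new functions (the synthetic arrays below are illustrative, not part of the commit; plot_brain_regions instantiates a BrainRegions() from the Allen structure tree when none is passed):

import numpy as np
from brainbox.ephys_plots import plot_brain_regions, plot_cdf

# two channels in CA3 (463) and two in PO (685), with depths along the probe in um
channel_ids = np.array([463, 463, 685, 685])
channel_depths = np.array([20., 40., 60., 80.])
fig, ax = plot_brain_regions(channel_ids, channel_depths=channel_depths)

# synthetic spikes: plot_cdf bins amplitudes per depth bin and cumulates from the largest amplitude down
rng = np.random.default_rng(0)
spike_times = np.sort(rng.uniform(0, 100, 5000))
spike_depths = rng.uniform(0, 3840, 5000)
spike_amps = rng.gamma(2., 5e-5, 5000)
data, fig, ax = plot_cdf(spike_amps, spike_depths, spike_times, display=True)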

brainbox/io/spikeglx.py

Lines changed: 3 additions & 3 deletions
@@ -128,9 +128,9 @@ def stream(pid, t0, nsecs=1, one=None, cache_folder=None, remove_cached=False, t
     samples_folder = Path(one.alyx._par.CACHE_DIR).joinpath('cache', typ)

     eid, pname = one.pid2eid(pid)
-    cbin_rec = one.list_datasets(eid, collection=f"*{pname}", filename='*ap.*bin', details=True)
-    ch_rec = one.list_datasets(eid, collection=f"*{pname}", filename='*ap.ch', details=True)
-    meta_rec = one.list_datasets(eid, collection=f"*{pname}", filename='*ap.meta', details=True)
+    cbin_rec = one.list_datasets(eid, collection=f"*{pname}", filename=f'*{typ}.*bin', details=True)
+    ch_rec = one.list_datasets(eid, collection=f"*{pname}", filename=f'*{typ}.ch', details=True)
+    meta_rec = one.list_datasets(eid, collection=f"*{pname}", filename=f'*{typ}.meta', details=True)
     ch_file = one._download_datasets(ch_rec)[0]
     one._download_datasets(meta_rec)[0]

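With the dataset filenames now parameterised on typ, the same helper can stream either band. A hedged sketch (the pid is a placeholder, and the keyword truncated in the signature above is assumed to be typ, as used in the function body):

from one.api import ONE
from brainbox.io.spikeglx import stream

one = ONE()
pid = 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx'  # placeholder probe insertion id
# previously the filename filter was hard-coded to '*ap.*'; with typ='lf' the
# lf.cbin / lf.ch / lf.meta datasets are resolved instead
sr = stream(pid, 100, nsecs=1, one=one, typ='lf')  # return value (a spikeglx-style reader) not shown in this diff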

brainbox/task/passive.py

Lines changed: 2 additions & 2 deletions
@@ -206,8 +206,8 @@ def get_stim_aligned_activity(stim_events, spike_times, spike_depths, z_score_fl
     base_intervals = np.c_[stim_times - base_stim, stim_times - pre_stim]
     out_intervals = stim_intervals[:, 1] > times[-1]

-    idx_stim = np.searchsorted(times, stim_intervals)[np.invert(out_intervals)]
-    idx_base = np.searchsorted(times, base_intervals)[np.invert(out_intervals)]
+    idx_stim = np.searchsorted(times, stim_intervals, side='right')[np.invert(out_intervals)]
+    idx_base = np.searchsorted(times, base_intervals, side='right')[np.invert(out_intervals)]

     stim_trials = np.zeros((depths.shape[0], n_bins, idx_stim.shape[0]))
     noise_trials = np.zeros((depths.shape[0], n_bins_base, idx_stim.shape[0]))
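The side='right' change only matters when an interval edge falls exactly on a bin time: with the default side='left' the edge resolves to the matching bin, with side='right' to the one after it. A minimal illustration:

import numpy as np

times = np.array([0., 1., 2., 3.])
np.searchsorted(times, 2.0)                # -> 2: inserts before the matching bin time
np.searchsorted(times, 2.0, side='right')  # -> 3: inserts after it, so an interval edge
                                           #    at exactly 2.0 is counted in the next bin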

ibllib/atlas/regions.py

Lines changed: 11 additions & 0 deletions
@@ -149,6 +149,17 @@ def _mapping_from_regions_list(self, new_map, lateralize=False):
         mapind = mapind[iregion]
         return mapind

+    def remap(self, region_ids, source_map='Allen', target_map='Beryl'):
+        """
+        Remap atlas region ids from a source map to a target map
+        :param region_ids: atlas ids to map
+        :param source_map: map name the original region_ids are in
+        :param target_map: map name to remap onto
+        :return:
+        """
+        _, inds = ismember(region_ids, self.id[self.mappings[source_map]])
+        return self.id[self.mappings[target_map][inds]]
+

 def regions_from_allen_csv():
     """

ibllib/oneibl/data_handlers.py

Lines changed: 15 additions & 4 deletions
@@ -7,7 +7,6 @@
 import abc
 from time import time

-from one.api import ONE
 from one.util import filter_datasets
 from one.alf.files import add_uuid_string
 from iblutil.io.parquet import np2str
@@ -27,10 +26,10 @@ def __init__(self, session_path, signature, one=None):
         :param one: ONE instance
         """
         self.session_path = session_path
-        self.one = one or ONE()
         self.signature = signature
+        self.one = one

-    def setup(self):
+    def setUp(self):
         """
         Function to optionally overload to download required data to run task
         :return:
@@ -42,7 +41,8 @@ def getData(self):
         Finds the datasets required for task based on input signatures
         :return:
         """
-
+        if self.one is None:
+            return
         session_datasets = self.one.list_datasets(self.one.path2eid(self.session_path), details=True)
         df = pd.DataFrame(columns=self.one._cache.datasets.columns)
         for file in self.signature['input_files']:
@@ -72,6 +72,17 @@ def cleanUp(self):
         pass


+class LocalDataHandler(DataHandler):
+    def __init__(self, session_path, signatures, one=None):
+        """
+        Data handler for running tasks locally, with no remote architecture or database connection
+        :param session_path: path to session
+        :param signatures: input and output file signatures
+        :param one: ONE instance
+        """
+        super().__init__(session_path, signatures, one=one)
+
+
 class ServerDataHandler(DataHandler):
     def __init__(self, session_path, signatures, one=None):
         """

ibllib/pipes/tasks.py

Lines changed: 67 additions & 17 deletions
@@ -6,29 +6,30 @@
 import time
 from _collections import OrderedDict
 import traceback
+import json

 from graphviz import Digraph

 from ibllib.misc import version
-import one.params
 from ibllib.oneibl import data_handlers
-
+import one.params
+from one.api import ONE

 _logger = logging.getLogger('ibllib')


 class Task(abc.ABC):
-    log = ""
-    cpu = 1
-    gpu = 0
+    log = ""  # placeholder to keep the log of the task for registration
+    cpu = 1  # CPU resource
+    gpu = 0  # GPU resources: as of now, either 0 or 1
     io_charge = 5  # integer percentage
     priority = 30  # integer percentage, 100 means highest priority
     ram = 4  # RAM needed to run (GB)
     one = None  # one instance (optional)
-    level = 0
-    outputs = None
+    level = 0  # level in the pipeline hierarchy: level 0 means there is no parent task
+    outputs = None  # placeholder for a list of Paths containing output files
     time_elapsed_secs = None
-    time_out_secs = None
+    time_out_secs = 3600 * 2  # time-out after which a task is considered dead
     version = version.ibllib()
     signature = {'input_files': [], 'output_files': []}  # list of tuples (filename, collection, required_flag)
     force = False  # whether or not to re-download missing input files on local server if not present
@@ -69,6 +70,11 @@ def run(self, **kwargs):
         wraps the _run() method with
         - error management
         - logging to variable
+        - writing a lock file if the GPU is used
+        - labelling the status property of the object. The status value is set to:
+             0: Complete
+            -1: Errored
+            -2: Didn't run as a lock was encountered
         """
         # if taskid or one properties are not available, local run only without alyx
         use_alyx = self.one is not None and self.taskid is not None
@@ -91,17 +97,20 @@ def run(self, **kwargs):
         # setup
         setup = self.setUp(**kwargs)
         _logger.info(f"Setup value is: {setup}")
+        self.status = 0
         if not setup:
             # case where outputs are present but don't have input files locally to rerun task
             # label task as complete
-            self.status = 0
             _, self.outputs = self.assert_expected_outputs()
-
         else:
             # run task
-            self.status = 0
             start_time = time.time()
             try:
+                if self.gpu >= 1:
+                    if not self._creates_lock():
+                        self.status = -2
+                        _logger.info(f"Job {self.__class__} exited as a lock was found")
+                        return
                 self.outputs = self._run(**kwargs)
                 _logger.info(f"Job {self.__class__} complete")
             except BaseException:
@@ -169,7 +178,6 @@ def setUp(self, **kwargs):
         :param kwargs:
         :return:
         """
-
         if self.location == 'server':
             self.get_signatures(**kwargs)

@@ -196,7 +204,6 @@ def setUp(self, **kwargs):
             # TODO in future should raise error if even after downloading don't have the correct files
             self.assert_expected_inputs(raise_error=False)
             return True
-
         else:
             self.data_handler = self.get_data_handler()
             self.data_handler.setUp()
@@ -206,9 +213,10 @@

     def tearDown(self):
         """
-        Function to optionally overload to check results
+        Function run after _run(); releases the GPU lock file if one was created
         """
-        pass
+        if self.gpu >= 1:
+            self._lock_file_path().unlink()

     def cleanUp(self):
         """
@@ -270,7 +278,9 @@ def get_data_handler(self, location=None):
         :return:
         """
         location = location or self.location
-
+        if location == 'local':
+            return data_handlers.LocalDataHandler(self.session_path, self.signature, one=self.one)
+        self.one = self.one or ONE()
         if location == 'server':
             dhandler = data_handlers.ServerDataHandler(self.session_path, self.signature, one=self.one)
         elif location == 'serverglobus':
@@ -281,9 +291,49 @@ def get_data_handler(self, location=None):
             dhandler = data_handlers.RemoteAwsDataHandler(self.session_path, self.signature, one=self.one)
         elif location == 'SDSC':
             dhandler = data_handlers.SDSCDataHandler(self, self.session_path, self.signature, one=self.one)
-
         return dhandler

+    @staticmethod
+    def make_lock_file(taskname="", time_out_secs=7200):
+        """Creates a GPU lock file with the given time-out in seconds"""
+        d = {'start': time.time(), 'name': taskname, 'time_out_secs': time_out_secs}
+        with open(Task._lock_file_path(), 'w+') as fid:
+            json.dump(d, fid)
+        return d
+
+    @staticmethod
+    def _lock_file_path():
+        """The lock file lives in ~/.one/gpu.lock"""
+        folder = Path.home().joinpath('.one')
+        folder.mkdir(exist_ok=True)
+        return folder.joinpath('gpu.lock')
+
+    def _make_lock_file(self):
+        """Creates a lock file with the current time"""
+        return Task.make_lock_file(self.name, self.time_out_secs)
+
+    def is_locked(self):
+        """Checks if there is a lock file for this given task"""
+        lock_file = self._lock_file_path()
+        if not lock_file.exists():
+            return False
+
+        with open(lock_file) as fid:
+            d = json.load(fid)
+        now = time.time()
+        if (now - d['start']) > d['time_out_secs']:
+            lock_file.unlink()
+            return False
+        else:
+            return True
+
+    def _creates_lock(self):
+        if self.is_locked():
+            return False
+        else:
+            self._make_lock_file()
+            return True
+

 class Pipeline(abc.ABC):
     """

ibllib/tests/test_atlas.py

Lines changed: 7 additions & 0 deletions
@@ -57,6 +57,13 @@ def test_mappings_not_lateralized(self):
         inds_[0] = 0
         assert np.all(inds == inds_)

+    def test_remap(self):
+        # Test mapping atlas ids from one map to another
+        atlas_id = np.array([463, 685])  # CA3 and PO
+        cosmos_atlas_id = self.brs.remap(atlas_id, source_map='Allen', target_map='Cosmos')
+        expected_cosmos_id = [1089, 549]  # HPF and TH
+        assert np.all(cosmos_atlas_id == expected_cosmos_id)
+

 class TestAtlasSlicesConversion(unittest.TestCase):

