Skip to content

Commit ef58861

Browse files
committed
add raw ephys plots to task pipeline
1 parent 380c5e3 commit ef58861

File tree

3 files changed

+102
-39
lines changed

3 files changed

+102
-39
lines changed

ibllib/pipes/ephys_preprocessing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ibllib.qc.camera import run_all_qc as run_camera_qc
2828
from ibllib.qc.dlc import DlcQC
2929
from ibllib.dsp import rms
30-
from ibllib.plots.figures import dlc_qc_plot, BehaviourPlots, LfpPlots, ApPlots
30+
from ibllib.plots.figures import dlc_qc_plot, BehaviourPlots, LfpPlots, ApPlots, BadChannelsAp
3131
from ibllib.plots.figures import SpikeSorting as SpikeSortingPlots
3232
from ibllib.plots.snapshot import ReportSnapshot
3333
from brainbox.behavior.dlc import likelihood_threshold, get_licks, get_pupil_diameter, get_smooth_pupil_diameter
@@ -138,6 +138,9 @@ def _run(self, overwrite=False):
138138
plot_task = LfpPlots(pid, session_path=self.session_path, one=self.one)
139139
_ = plot_task.run()
140140
self.plot_tasks.append(plot_task)
141+
plot_task = BadChannelsAp(pid, session_path=self.session_path, one=self.one)
142+
_ = plot_task.run()
143+
self.plot_tasks.append(plot_task)
141144

142145
except AssertionError:
143146
_logger.error(traceback.format_exc())

ibllib/plots/figures.py

Lines changed: 96 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
import one.alf.io as alfio
1818
from one.alf.exceptions import ALFObjectNotFound
1919
from ibllib.io.video import get_video_frame, url_from_eid
20+
from ibllib.io import spikeglx
2021
from brainbox.plot import driftmap
22+
from brainbox.io.spikeglx import stream
2123
from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \
2224
plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist
2325
from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions
2426
from brainbox.io.one import load_spike_sorting_fast
2527
from brainbox.behavior import training
2628
from iblutil.numerical import ismember
29+
from ibllib.plots.misc import Density
2730

2831

2932
logger = logging.getLogger('ibllib')
@@ -398,6 +401,14 @@ def _run(self):
398401
"""runs for initiated PID, streams data, destripe and check bad channels"""
399402
assert self.pid
400403
self.eqcs = []
404+
T0 = 60 * 30
405+
SNAPSHOT_LABEL = "raw_ephys_bad_channels"
406+
output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
407+
if len(output_files) == 4:
408+
return output_files
409+
410+
self.output_directory.mkdir(exist_ok=True, parents=True)
411+
401412
if self.location != 'server':
402413
self.histology_status = self.get_histology_status()
403414
electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
@@ -406,74 +417,124 @@ def _run(self):
406417
electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
407418
electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
408419
electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
420+
electrodes['title'] = self.histology_status
409421
else:
410422
electrodes = None
423+
424+
sr, t0 = stream(self.pid, T0, nsecs=1, one=self.one)
425+
raw = sr[:, :-sr.nsync].T
411426
else:
412427
electrodes = None
428+
ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
429+
if ap_file is not None:
430+
sr = spikeglx.Reader(ap_file)
431+
raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T
432+
else:
433+
return []
413434

414-
SNAPSHOT_LABEL = "raw_ephys_bad_channels"
415-
eid, pname = self.one.pid2eid(self.pid)
416-
output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
417-
if len(output_files) == 4:
418-
return output_files
419-
self.output_directory.mkdir(exist_ok=True, parents=True)
420-
from brainbox.io.spikeglx import stream
421-
T0 = 60 * 30
422-
sr, t0 = stream(self.pid, T0, nsecs=1, one=self.one)
423-
raw = sr[:, :-sr.nsync].T
424435
channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
425436
_, eqcs, output_files = ephys_bad_channels(
426437
raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, channels=electrodes,
427-
title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions)
438+
title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
428439
self.eqcs = eqcs
429440
return output_files
430441

431442

432-
def ephys_bad_channels(raw, fs, channel_labels, channel_features, channels=None, title="ephys_bad_channels", save_dir=None,
433-
destripe=False, eqcs=None, br=None):
443+
def ephys_bad_channels(raw, fs, channel_labels, channel_features, channels=None, title="ephys_bad_channels",
444+
save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
434445
nc, ns = raw.shape
435446
rl = ns / fs
447+
448+
def gain2level(gain):
449+
return 10 ** (gain / 20) * 4 * np.array([-1, 1])
450+
436451
if fs >= 2600: # AP band
437452
ylim_rms = [0, 100]
438453
ylim_psd_hf = [0, 0.1]
439454
eqc_xrange = [450, 500]
440455
butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
441456
eqc_gain = - 90
457+
eqc_levels = gain2level(eqc_gain)
442458
else:
443459
# we are working with the LFP
444460
ylim_rms = [0, 1000]
445461
ylim_psd_hf = [0, 1]
446462
eqc_xrange = [450, 950]
447463
butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
448464
eqc_gain = - 78
465+
eqc_levels = gain2level(eqc_gain)
449466

450467
inoisy = np.where(channel_labels == 2)[0]
451468
idead = np.where(channel_labels == 1)[0]
452469
ioutside = np.where(channel_labels == 3)[0]
453-
from viewspikes.gui import viewephys
454470

455471
# display voltage traces
456472
eqcs = [] if eqcs is None else eqcs
457473
# butterworth, for display only
458474
sos = scipy.signal.butter(**butter_kwargs, output='sos')
459475
butt = scipy.signal.sosfiltfilt(sos, raw)
460-
eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))
461-
if destripe:
462-
dest = voltage.destripe(raw, fs=fs, channel_labels=channel_labels)
463-
eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
464-
eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))
465-
466-
for eqc in eqcs:
467-
y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
468-
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
469-
y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
470-
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
471-
y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
472-
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')
476+
477+
if plot_backend == 'matplotlib':
478+
_, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
479+
eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1],
480+
cmap='Greys'))
481+
482+
if destripe:
483+
dest = voltage.destripe(raw, fs=fs, channel_labels=channel_labels)
484+
_, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
485+
eqcs.append(Density(dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1],
486+
cmap='Greys'))
487+
_, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
488+
eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
489+
vmax=eqc_levels[1], cmap='Greys'))
490+
491+
for eqc in eqcs:
492+
y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
493+
eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
494+
y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
495+
eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
496+
y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
497+
eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)
498+
499+
eqc.ax.set_xlim(*eqc_xrange)
500+
eqc.ax.set_ylim(0, nc)
501+
eqc.ax.set_ylabel('Channel index')
502+
eqc.ax.set_title(f'{pid_info}_{eqc.title}')
503+
set_axis_label_size(eqc.ax)
504+
505+
ax = eqc.figure.axes[1]
506+
if channels is not None:
507+
chn_title = channels.get('title', None)
508+
plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
509+
title=chn_title)
510+
set_axis_label_size(ax)
511+
else:
512+
remove_axis_outline(ax)
513+
else:
514+
from viewspikes.gui import viewephys # noqa
515+
eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))
516+
517+
if destripe:
518+
dest = voltage.destripe(raw, fs=fs, channel_labels=channel_labels)
519+
eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
520+
eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))
521+
522+
for eqc in eqcs:
523+
y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
524+
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
525+
y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
526+
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
527+
y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
528+
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')
529+
530+
eqcs[0].ctrl.set_gain(eqc_gain)
531+
eqcs[0].resize(1960, 1200)
532+
eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
533+
eqcs[0].viewBox_seismic.setYRange(0, nc)
534+
eqcs[0].ctrl.propagate()
535+
473536
# display features
474537
fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
475-
476-
# fig.suptitle(f"pid:{pid}, \n eid:{eid}, \n {one.eid2path(eid).parts[-3:]}, {pname}")
477538
fig.suptitle(title)
478539
axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
479540
axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)
@@ -499,18 +560,16 @@ def ephys_bad_channels(raw, fs, channel_labels, channel_features, channels=None,
499560
axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
500561
axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')
501562

502-
eqcs[0].ctrl.set_gain(eqc_gain)
503-
eqcs[0].resize(1960, 1200)
504-
eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
505-
eqcs[0].viewBox_seismic.setYRange(0, nc)
506-
eqcs[0].ctrl.propagate()
507-
508563
if save_dir is not None:
509564
output_files = [Path(save_dir).joinpath(f"{title}.png")]
510565
fig.savefig(output_files[0])
511566
for eqc in eqcs:
512-
output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
513-
eqc.grab().save(str(output_files[-1]))
567+
if plot_backend == 'matplotlib':
568+
output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
569+
eqc.figure.savefig(str(output_files[-1]))
570+
else:
571+
output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
572+
eqc.grab().save(str(output_files[-1]))
514573
return fig, eqcs, output_files
515574
else:
516575
return fig, eqcs

ibllib/plots/misc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def insert_zeros(trace):
7272

7373

7474
class Density:
75-
def __init__(self, w, fs=1, cmap='bone', ax=None, taxis=0, **kwargs):
75+
def __init__(self, w, fs=1, cmap='bone', ax=None, taxis=0, title=None, **kwargs):
7676
"""
7777
Matplotlib display of traces as a density display
7878
@@ -101,6 +101,7 @@ def __init__(self, w, fs=1, cmap='bone', ax=None, taxis=0, **kwargs):
101101
ax.set_xlabel(xlabel)
102102
self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press)
103103
self.ax = ax
104+
self.title = title or None
104105

105106
def on_key_press(self, event):
106107
if event.key == 'ctrl+a':

0 commit comments

Comments
 (0)