1717import one .alf .io as alfio
1818from one .alf .exceptions import ALFObjectNotFound
1919from ibllib .io .video import get_video_frame , url_from_eid
20+ from ibllib .io import spikeglx
2021from brainbox .plot import driftmap
22+ from brainbox .io .spikeglx import stream
2123from brainbox .behavior .dlc import SAMPLING , plot_trace_on_frame , plot_wheel_position , plot_lick_hist , \
2224 plot_lick_raster , plot_motion_energy_hist , plot_speed_hist , plot_pupil_diameter_hist
2325from brainbox .ephys_plots import image_lfp_spectrum_plot , image_rms_plot , plot_brain_regions
2426from brainbox .io .one import load_spike_sorting_fast
2527from brainbox .behavior import training
2628from iblutil .numerical import ismember
29+ from ibllib .plots .misc import Density
2730
2831
2932logger = logging .getLogger ('ibllib' )
@@ -398,6 +401,14 @@ def _run(self):
398401 """runs for initiated PID, streams data, destripe and check bad channels"""
399402 assert self .pid
400403 self .eqcs = []
404+ T0 = 60 * 30
405+ SNAPSHOT_LABEL = "raw_ephys_bad_channels"
406+ output_files = list (self .output_directory .glob (f'{ SNAPSHOT_LABEL } *' ))
407+ if len (output_files ) == 4 :
408+ return output_files
409+
410+ self .output_directory .mkdir (exist_ok = True , parents = True )
411+
401412 if self .location != 'server' :
402413 self .histology_status = self .get_histology_status ()
403414 electrodes = self .get_channels ('electrodeSites' , f'alf/{ self .pname } ' )
@@ -406,74 +417,124 @@ def _run(self):
406417 electrodes ['ibr' ] = ismember (electrodes ['atlas_id' ], self .brain_regions .id )[1 ]
407418 electrodes ['acronym' ] = self .brain_regions .acronym [electrodes ['ibr' ]]
408419 electrodes ['name' ] = self .brain_regions .name [electrodes ['ibr' ]]
420+ electrodes ['title' ] = self .histology_status
409421 else :
410422 electrodes = None
423+
424+ sr , t0 = stream (self .pid , T0 , nsecs = 1 , one = self .one )
425+ raw = sr [:, :- sr .nsync ].T
411426 else :
412427 electrodes = None
428+ ap_file = next (self .session_path .joinpath ('raw_ephys_data' , self .pname ).glob ('*ap.*bin' ), None )
429+ if ap_file is not None :
430+ sr = spikeglx .Reader (ap_file )
431+ raw = sr [int ((sr .fs * T0 )):int ((sr .fs * (T0 + 1 ))), :- sr .nsync ].T
432+ else :
433+ return []
413434
414- SNAPSHOT_LABEL = "raw_ephys_bad_channels"
415- eid , pname = self .one .pid2eid (self .pid )
416- output_files = list (self .output_directory .glob (f'{ SNAPSHOT_LABEL } *' ))
417- if len (output_files ) == 4 :
418- return output_files
419- self .output_directory .mkdir (exist_ok = True , parents = True )
420- from brainbox .io .spikeglx import stream
421- T0 = 60 * 30
422- sr , t0 = stream (self .pid , T0 , nsecs = 1 , one = self .one )
423- raw = sr [:, :- sr .nsync ].T
424435 channel_labels , channel_features = voltage .detect_bad_channels (raw , sr .fs )
425436 _ , eqcs , output_files = ephys_bad_channels (
426437 raw = raw , fs = sr .fs , channel_labels = channel_labels , channel_features = channel_features , channels = electrodes ,
427- title = SNAPSHOT_LABEL , destripe = True , save_dir = self .output_directory , br = self .brain_regions )
438+ title = SNAPSHOT_LABEL , destripe = True , save_dir = self .output_directory , br = self .brain_regions , pid_info = self . pid_label )
428439 self .eqcs = eqcs
429440 return output_files
430441
431442
432- def ephys_bad_channels (raw , fs , channel_labels , channel_features , channels = None , title = "ephys_bad_channels" , save_dir = None ,
433- destripe = False , eqcs = None , br = None ):
443+ def ephys_bad_channels (raw , fs , channel_labels , channel_features , channels = None , title = "ephys_bad_channels" ,
444+ save_dir = None , destripe = False , eqcs = None , br = None , pid_info = None , plot_backend = 'matplotlib' ):
434445 nc , ns = raw .shape
435446 rl = ns / fs
447+
448+ def gain2level (gain ):
449+ return 10 ** (gain / 20 ) * 4 * np .array ([- 1 , 1 ])
450+
436451 if fs >= 2600 : # AP band
437452 ylim_rms = [0 , 100 ]
438453 ylim_psd_hf = [0 , 0.1 ]
439454 eqc_xrange = [450 , 500 ]
440455 butter_kwargs = {'N' : 3 , 'Wn' : 300 / fs * 2 , 'btype' : 'highpass' }
441456 eqc_gain = - 90
457+ eqc_levels = gain2level (eqc_gain )
442458 else :
443459 # we are working with the LFP
444460 ylim_rms = [0 , 1000 ]
445461 ylim_psd_hf = [0 , 1 ]
446462 eqc_xrange = [450 , 950 ]
447463 butter_kwargs = {'N' : 3 , 'Wn' : np .array ([2 , 125 ]) / fs * 2 , 'btype' : 'bandpass' }
448464 eqc_gain = - 78
465+ eqc_levels = gain2level (eqc_gain )
449466
450467 inoisy = np .where (channel_labels == 2 )[0 ]
451468 idead = np .where (channel_labels == 1 )[0 ]
452469 ioutside = np .where (channel_labels == 3 )[0 ]
453- from viewspikes .gui import viewephys
454470
455471 # display voltage traces
456472 eqcs = [] if eqcs is None else eqcs
457473 # butterworth, for display only
458474 sos = scipy .signal .butter (** butter_kwargs , output = 'sos' )
459475 butt = scipy .signal .sosfiltfilt (sos , raw )
460- eqcs .append (viewephys (butt , fs = fs , channels = channels , title = 'highpass' , br = br ))
461- if destripe :
462- dest = voltage .destripe (raw , fs = fs , channel_labels = channel_labels )
463- eqcs .append (viewephys (dest , fs = fs , channels = channels , title = 'destripe' , br = br ))
464- eqcs .append (viewephys ((butt - dest ), fs = fs , channels = channels , title = 'difference' , br = br ))
465-
466- for eqc in eqcs :
467- y , x = np .meshgrid (ioutside , np .linspace (0 , rl * 1e3 , 500 ))
468- eqc .ctrl .add_scatter (x .flatten (), y .flatten (), rgb = (164 , 142 , 35 ), label = 'outside' )
469- y , x = np .meshgrid (inoisy , np .linspace (0 , rl * 1e3 , 500 ))
470- eqc .ctrl .add_scatter (x .flatten (), y .flatten (), rgb = (255 , 0 , 0 ), label = 'noisy' )
471- y , x = np .meshgrid (idead , np .linspace (0 , rl * 1e3 , 500 ))
472- eqc .ctrl .add_scatter (x .flatten (), y .flatten (), rgb = (0 , 0 , 255 ), label = 'dead' )
476+
477+ if plot_backend == 'matplotlib' :
478+ _ , axs = plt .subplots (1 , 2 , gridspec_kw = {'width_ratios' : [.95 , .05 ]}, figsize = (16 , 9 ))
479+ eqcs .append (Density (butt , fs = fs , taxis = 1 , ax = axs [0 ], title = 'highpass' , vmin = eqc_levels [0 ], vmax = eqc_levels [1 ],
480+ cmap = 'Greys' ))
481+
482+ if destripe :
483+ dest = voltage .destripe (raw , fs = fs , channel_labels = channel_labels )
484+ _ , axs = plt .subplots (1 , 2 , gridspec_kw = {'width_ratios' : [.95 , .05 ]}, figsize = (16 , 9 ))
485+ eqcs .append (Density (dest , fs = fs , taxis = 1 , ax = axs [0 ], title = 'destripe' , vmin = eqc_levels [0 ], vmax = eqc_levels [1 ],
486+ cmap = 'Greys' ))
487+ _ , axs = plt .subplots (1 , 2 , gridspec_kw = {'width_ratios' : [.95 , .05 ]}, figsize = (16 , 9 ))
488+ eqcs .append (Density ((butt - dest ), fs = fs , taxis = 1 , ax = axs [0 ], title = 'difference' , vmin = eqc_levels [0 ],
489+ vmax = eqc_levels [1 ], cmap = 'Greys' ))
490+
491+ for eqc in eqcs :
492+ y , x = np .meshgrid (ioutside , np .linspace (0 , rl * 1e3 , 500 ))
493+ eqc .ax .scatter (x .flatten (), y .flatten (), c = 'goldenrod' , s = 4 )
494+ y , x = np .meshgrid (inoisy , np .linspace (0 , rl * 1e3 , 500 ))
495+ eqc .ax .scatter (x .flatten (), y .flatten (), c = 'r' , s = 4 )
496+ y , x = np .meshgrid (idead , np .linspace (0 , rl * 1e3 , 500 ))
497+ eqc .ax .scatter (x .flatten (), y .flatten (), c = 'b' , s = 4 )
498+
499+ eqc .ax .set_xlim (* eqc_xrange )
500+ eqc .ax .set_ylim (0 , nc )
501+ eqc .ax .set_ylabel ('Channel index' )
502+ eqc .ax .set_title (f'{ pid_info } _{ eqc .title } ' )
503+ set_axis_label_size (eqc .ax )
504+
505+ ax = eqc .figure .axes [1 ]
506+ if channels is not None :
507+ chn_title = channels .get ('title' , None )
508+ plot_brain_regions (channels ['atlas_id' ], brain_regions = br , display = True , ax = ax ,
509+ title = chn_title )
510+ set_axis_label_size (ax )
511+ else :
512+ remove_axis_outline (ax )
513+ else :
514+ from viewspikes .gui import viewephys # noqa
515+ eqcs .append (viewephys (butt , fs = fs , channels = channels , title = 'highpass' , br = br ))
516+
517+ if destripe :
518+ dest = voltage .destripe (raw , fs = fs , channel_labels = channel_labels )
519+ eqcs .append (viewephys (dest , fs = fs , channels = channels , title = 'destripe' , br = br ))
520+ eqcs .append (viewephys ((butt - dest ), fs = fs , channels = channels , title = 'difference' , br = br ))
521+
522+ for eqc in eqcs :
523+ y , x = np .meshgrid (ioutside , np .linspace (0 , rl * 1e3 , 500 ))
524+ eqc .ctrl .add_scatter (x .flatten (), y .flatten (), rgb = (164 , 142 , 35 ), label = 'outside' )
525+ y , x = np .meshgrid (inoisy , np .linspace (0 , rl * 1e3 , 500 ))
526+ eqc .ctrl .add_scatter (x .flatten (), y .flatten (), rgb = (255 , 0 , 0 ), label = 'noisy' )
527+ y , x = np .meshgrid (idead , np .linspace (0 , rl * 1e3 , 500 ))
528+ eqc .ctrl .add_scatter (x .flatten (), y .flatten (), rgb = (0 , 0 , 255 ), label = 'dead' )
529+
530+ eqcs [0 ].ctrl .set_gain (eqc_gain )
531+ eqcs [0 ].resize (1960 , 1200 )
532+ eqcs [0 ].viewBox_seismic .setXRange (* eqc_xrange )
533+ eqcs [0 ].viewBox_seismic .setYRange (0 , nc )
534+ eqcs [0 ].ctrl .propagate ()
535+
473536 # display features
474537 fig , axs = plt .subplots (2 , 2 , sharex = True , figsize = [16 , 9 ], tight_layout = True )
475-
476- # fig.suptitle(f"pid:{pid}, \n eid:{eid}, \n {one.eid2path(eid).parts[-3:]}, {pname}")
477538 fig .suptitle (title )
478539 axs [0 , 0 ].plot (channel_features ['rms_raw' ] * 1e6 )
479540 axs [0 , 0 ].set (title = 'rms' , xlabel = 'channel number' , ylabel = 'rms (uV)' , ylim = ylim_rms )
@@ -499,18 +560,16 @@ def ephys_bad_channels(raw, fs, channel_labels, channel_features, channels=None,
499560 axs [1 , 1 ].plot (inoisy , inoisy * 0 + fs / 4 , 'xr' )
500561 axs [1 , 1 ].plot (ioutside , ioutside * 0 + fs / 4 , 'xy' )
501562
502- eqcs [0 ].ctrl .set_gain (eqc_gain )
503- eqcs [0 ].resize (1960 , 1200 )
504- eqcs [0 ].viewBox_seismic .setXRange (* eqc_xrange )
505- eqcs [0 ].viewBox_seismic .setYRange (0 , nc )
506- eqcs [0 ].ctrl .propagate ()
507-
508563 if save_dir is not None :
509564 output_files = [Path (save_dir ).joinpath (f"{ title } .png" )]
510565 fig .savefig (output_files [0 ])
511566 for eqc in eqcs :
512- output_files .append (Path (save_dir ).joinpath (f"{ title } _{ eqc .windowTitle ()} .png" ))
513- eqc .grab ().save (str (output_files [- 1 ]))
567+ if plot_backend == 'matplotlib' :
568+ output_files .append (Path (save_dir ).joinpath (f"{ title } _{ eqc .title } .png" ))
569+ eqc .figure .savefig (str (output_files [- 1 ]))
570+ else :
571+ output_files .append (Path (save_dir ).joinpath (f"{ title } _{ eqc .windowTitle ()} .png" ))
572+ eqc .grab ().save (str (output_files [- 1 ]))
514573 return fig , eqcs , output_files
515574 else :
516575 return fig , eqcs
0 commit comments