1616import one .alf .io as alfio
1717from one .alf .exceptions import ALFObjectNotFound
1818from ibllib .io .video import get_video_frame , url_from_eid
19+ from brainbox .plot import driftmap
1920from brainbox .behavior .dlc import SAMPLING , plot_trace_on_frame , plot_wheel_position , plot_lick_hist , \
2021 plot_lick_raster , plot_motion_energy_hist , plot_speed_hist , plot_pupil_diameter_hist
22+ from brainbox .io .one import load_spike_sorting_fast
23+ from brainbox .ephys_plots import plot_brain_regions
24+
2125
2226logger = logging .getLogger ('ibllib' )
2327
@@ -32,7 +36,23 @@ class SpikeSorting(ReportSnapshotProbe):
3236
3337 def _run (self ):
3438 """runs for initiated PID, streams data, destripe and check bad channels"""
35- assert self .pid
39+ all_here , output_files = self .assert_expected (self .output_files , silent = True )
40+ if all_here :
41+ return output_files
42+ spikes , clusters , channels = load_spike_sorting_fast (
43+ eid = self .eid , probe = self .pname , one = self .one , nested = False ,
44+ dataset_types = ['spikes.depths' ], brain_regions = self .brain_regions )
45+ fig , axs = plt .subplots (1 , 2 , gridspec_kw = {'width_ratios' : [.95 , .05 ]}, sharey = True , figsize = (16 , 9 ))
46+ driftmap (spikes .times , spikes .depths , t_bin = 0.007 , d_bin = 10 , vmax = 0.5 , ax = axs [0 ])
47+ if 'atlas_id' in channels .keys ():
48+ plot_brain_regions (channels ['atlas_id' ], channel_depths = channels ['axial_um' ],
49+ brain_regions = None , display = True , ax = axs [1 ])
50+ title_str = f"{ self .pid_label } , { self .pid } , { spikes .clusters .size :_} spikes, { clusters .depths .size :_} clusters"
51+ axs [0 ].set (ylim = [0 , 3800 ], title = title_str )
52+ output_files = [self .output_directory .joinpath ("spike_sorting_raster.png" )]
53+ fig .savefig (output_files [0 ])
54+ plt .close (fig )
55+ return output_files
3656
3757 def get_probe_signature (self ):
3858 input_signature = [('spikes.times.npy' , f'alf/{ self .pname } ' , True ),
@@ -55,7 +75,7 @@ def get_probe_signature(self):
5575 pname = self .pname
5676 input_signature = [('*ap.meta' , f'raw_ephys_data/{ pname } ' , True ),
5777 ('*ap.ch' , f'raw_ephys_data/{ pname } ' , False )]
58- # ('*ap.cbin', f'raw_ephys_data/{pname}', False)]
78+ # ('*ap.cbin', f'raw_ephys_data/{pname}', False)]
5979 output_signature = [('raw_ephys_bad_channels.png' , f'snapshot/{ pname } ' , True ),
6080 ('raw_ephys_bad_channels_highpass.png' , f'snapshot/{ pname } ' , True ),
6181 ('raw_ephys_bad_channels_highpass.png' , f'snapshot/{ pname } ' , True ),
@@ -69,19 +89,18 @@ def _run(self):
6989 assert self .pid
7090 SNAPSHOT_LABEL = "raw_ephys_bad_channels"
7191 eid , pname = self .one .pid2eid (self .pid )
72- output_directory = self .session_path .joinpath ('snapshot' , pname )
73- output_files = list (output_directory .glob (f'{ SNAPSHOT_LABEL } *' ))
92+ output_files = list (self .output_directory .glob (f'{ SNAPSHOT_LABEL } *' ))
7493 if len (output_files ) == 4 :
7594 return output_files
76- output_directory .mkdir (exist_ok = True , parents = True )
95+ self . output_directory .mkdir (exist_ok = True , parents = True )
7796 from brainbox .io .spikeglx import stream
7897 T0 = 60 * 30
7998 sr , t0 = stream (self .pid , T0 , nsecs = 1 , one = self .one )
8099 raw = sr [:, :- sr .nsync ].T
81100 channel_labels , channel_features = voltage .detect_bad_channels (raw , sr .fs )
82101 _ , _ , output_files = ephys_bad_channels (
83102 raw = raw , fs = sr .fs , channel_labels = channel_labels , channel_features = channel_features ,
84- title = SNAPSHOT_LABEL , destripe = True , save_dir = output_directory )
103+ title = SNAPSHOT_LABEL , destripe = True , save_dir = self . output_directory )
85104 return output_files
86105
87106
0 commit comments