@@ -34,31 +34,40 @@ class SpikeSorting(ReportSnapshotProbe):
3434 :param **kwargs: keyword arguments passed to tasks.Task
3535 """
3636
37- def _run (self ):
37+ def _run (self , collection = None ):
3838 """runs for initiated PID, streams data, destripe and check bad channels"""
3939 all_here , output_files = self .assert_expected (self .output_files , silent = True )
40- if all_here :
40+ spike_sorting_runs = self .one .list_datasets (self .eid , filename = 'spikes.times.npy' , collection = f'alf/{ self .pname } *' )
41+ if all_here and len (output_files ) == len (spike_sorting_runs ):
4142 return output_files
42- spikes , clusters , channels = load_spike_sorting_fast (
43- eid = self .eid , probe = self .pname , one = self .one , nested = False ,
44- dataset_types = ['spikes.depths' ], brain_regions = self .brain_regions )
45- fig , axs = plt .subplots (1 , 2 , gridspec_kw = {'width_ratios' : [.95 , .05 ]}, sharey = True , figsize = (16 , 9 ))
46- driftmap (spikes .times , spikes .depths , t_bin = 0.007 , d_bin = 10 , vmax = 0.5 , ax = axs [0 ])
47- if 'atlas_id' in channels .keys ():
48- plot_brain_regions (channels ['atlas_id' ], channel_depths = channels ['axial_um' ],
49- brain_regions = None , display = True , ax = axs [1 ])
50- title_str = f"{ self .pid_label } , { self .pid } , { spikes .clusters .size :_} spikes, { clusters .depths .size :_} clusters"
51- axs [0 ].set (ylim = [0 , 3800 ], title = title_str )
52- output_files = [self .output_directory .joinpath ("spike_sorting_raster.png" )]
53- fig .savefig (output_files [0 ])
54- plt .close (fig )
43+ logger .info (self .output_directory )
44+ output_files = []
45+ for run in spike_sorting_runs :
46+ collection = str (Path (run ).parent )
47+ spikes , clusters , channels = load_spike_sorting_fast (
48+ eid = self .eid , probe = self .pname , one = self .one , nested = False , collection = collection ,
49+ dataset_types = ['spikes.depths' ], brain_regions = self .brain_regions )
50+
51+ fig , axs = plt .subplots (1 , 2 , gridspec_kw = {'width_ratios' : [.95 , .05 ]}, sharey = True , figsize = (16 , 9 ))
52+ driftmap (spikes .times , spikes .depths , t_bin = 0.007 , d_bin = 10 , vmax = 0.5 , ax = axs [0 ])
53+ if 'atlas_id' in channels .keys ():
54+ plot_brain_regions (channels ['atlas_id' ], channel_depths = channels ['axial_um' ],
55+ brain_regions = None , display = True , ax = axs [1 ])
56+ title_str = f"{ self .pid_label } , { collection } , { self .pid } \n { spikes .clusters .size :_} spikes, { clusters .depths .size :_} clusters"
57+ logger .info (title_str .replace ("\n " , "" ))
58+ axs [0 ].set (ylim = [0 , 3800 ], title = title_str )
59+ run_label = str (Path (collection ).relative_to (f'alf/{ self .pname } ' ))
60+ run_label = "" if run_label == '.' else run_label
61+ output_files .append (self .output_directory .joinpath (f"spike_sorting_raster_{ run_label } .png" ))
62+ fig .savefig (output_files [- 1 ])
63+ plt .close (fig )
5564 return output_files
5665
5766 def get_probe_signature (self ):
5867 input_signature = [('spikes.times.npy' , f'alf/{ self .pname } ' , True ),
5968 ('spikes.amps.npy' , f'alf/{ self .pname } ' , True ),
6069 ('spikes.depths.npy' , f'alf/{ self .pname } ' , True )]
61- output_signature = [('spike_sorting_raster.png' , f'snapshot/{ self .pname } ' , True )]
70+ output_signature = [('spike_sorting_raster* .png' , f'snapshot/{ self .pname } ' , True )]
6271 self .signature = {'input_files' : input_signature , 'output_files' : output_signature }
6372
6473
0 commit comments