Skip to content

Commit 99fbc0f

Browse files
committed
figures spike sorting: run each revision
1 parent 09e6462 commit 99fbc0f

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

ibllib/plots/figures.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,40 @@ class SpikeSorting(ReportSnapshotProbe):
3434
:param **kwargs: keyword arguments passed to tasks.Task
3535
"""
3636

37-
def _run(self):
37+
def _run(self, collection=None):
3838
"""runs for initiated PID, streams data, destripe and check bad channels"""
3939
all_here, output_files = self.assert_expected(self.output_files, silent=True)
40-
if all_here:
40+
spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
41+
if all_here and len(output_files) == len(spike_sorting_runs):
4142
return output_files
42-
spikes, clusters, channels = load_spike_sorting_fast(
43-
eid=self.eid, probe=self.pname, one=self.one, nested=False,
44-
dataset_types=['spikes.depths'], brain_regions=self.brain_regions)
45-
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, sharey=True, figsize=(16, 9))
46-
driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
47-
if 'atlas_id' in channels.keys():
48-
plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
49-
brain_regions=None, display=True, ax=axs[1])
50-
title_str = f"{self.pid_label}, {self.pid}, {spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
51-
axs[0].set(ylim=[0, 3800], title=title_str)
52-
output_files = [self.output_directory.joinpath("spike_sorting_raster.png")]
53-
fig.savefig(output_files[0])
54-
plt.close(fig)
43+
logger.info(self.output_directory)
44+
output_files = []
45+
for run in spike_sorting_runs:
46+
collection = str(Path(run).parent)
47+
spikes, clusters, channels = load_spike_sorting_fast(
48+
eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=collection,
49+
dataset_types=['spikes.depths'], brain_regions=self.brain_regions)
50+
51+
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, sharey=True, figsize=(16, 9))
52+
driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
53+
if 'atlas_id' in channels.keys():
54+
plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
55+
brain_regions=None, display=True, ax=axs[1])
56+
title_str = f"{self.pid_label}, {collection}, {self.pid} \n {spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
57+
logger.info(title_str.replace("\n", ""))
58+
axs[0].set(ylim=[0, 3800], title=title_str)
59+
run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
60+
run_label = "" if run_label == '.' else run_label
61+
output_files.append(self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png"))
62+
fig.savefig(output_files[-1])
63+
plt.close(fig)
5564
return output_files
5665

5766
def get_probe_signature(self):
5867
input_signature = [('spikes.times.npy', f'alf/{self.pname}', True),
5968
('spikes.amps.npy', f'alf/{self.pname}', True),
6069
('spikes.depths.npy', f'alf/{self.pname}', True)]
61-
output_signature = [('spike_sorting_raster.png', f'snapshot/{self.pname}', True)]
70+
output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
6271
self.signature = {'input_files': input_signature, 'output_files': output_signature}
6372

6473

0 commit comments

Comments
 (0)