
Commit 842f71f

Moved sorting steps in run_kilosort to gui-compatible subroutine
1 parent: e7d138b

2 files changed: +173 −105 lines

kilosort/gui/sorter.py

Lines changed: 114 additions & 102 deletions
@@ -8,8 +8,9 @@

 import kilosort
 from kilosort.run_kilosort import (
-    setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
-    detect_spikes, cluster_spikes, save_sorting, close_logger
+    # setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
+    # detect_spikes, cluster_spikes, save_sorting, close_logger
+    setup_logger, _sort
 )
 from kilosort.io import save_preprocessing
 from kilosort.utils import (
@@ -46,105 +47,116 @@ def run(self):
         results_dir.mkdir(parents=True)

         setup_logger(results_dir)
-        verbose = settings['verbose_log']
-
-        try:
-            logger.info(f"Kilosort version {kilosort.__version__}")
-            logger.info(f"Sorting {self.data_path}")
-            clear_cache = settings['clear_cache']
-            if clear_cache:
-                logger.info('clear_cache=True')
-            logger.info('-'*40)
-
-            tic0 = time.time()
-
-            if probe['chanMap'].max() >= settings['n_chan_bin']:
-                raise ValueError(
-                    f'Largest value of chanMap exceeds channel count of data, '
-                    'make sure chanMap is 0-indexed.'
-                )
-
-            if settings['nt0min'] is None:
-                settings['nt0min'] = int(20 * settings['nt']/61)
-            data_dtype = settings['data_dtype']
-            device = self.device
-            save_preprocessed_copy = settings['save_preprocessed_copy']
-            do_CAR = settings['do_CAR']
-            invert_sign = settings['invert_sign']
-            if not do_CAR:
-                logger.info("Skipping common average reference.")
-
-            ops = initialize_ops(settings, probe, data_dtype, do_CAR,
-                                 invert_sign, device, save_preprocessed_copy)
-
-            # Pretty-print ops and probe for log
-            logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
-            logger.debug(f"Probe dictionary:\n\n{probe_as_string(ops['probe'])}\n")
-
-            # TODO: add support for file object through data conversion
-            # Set preprocessing and drift correction parameters
-            ops = compute_preprocessing(ops, self.device, tic0=tic0,
-                                        file_object=self.file_object)
-            np.random.seed(1)
-            torch.cuda.manual_seed_all(1)
-            torch.random.manual_seed(1)
-            ops, bfile, st0 = compute_drift_correction(
-                ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
-                file_object=self.file_object, clear_cache=clear_cache
-            )
-
-            # Check scale of data for log file
-            b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
-            logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
-
-            if save_preprocessed_copy:
-                save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
-
-            # Will be None if nblocks = 0 (no drift correction)
-            if st0 is not None:
-                self.dshift = ops['dshift']
-                self.st0 = st0
-                self.plotDataReady.emit('drift')
-
-            # Sort spikes and save results
-            st, tF, Wall0, clu0 = detect_spikes(
-                ops, self.device, bfile, tic0=tic0,
-                progress_bar=self.progress_bar, clear_cache=clear_cache,
-                verbose=verbose
-            )
-
-            self.Wall0 = Wall0
-            self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
-            self.clu0 = clu0
-            self.plotDataReady.emit('diagnostics')
-
-            clu, Wall, _ = cluster_spikes(
-                st, tF, ops, self.device, bfile, tic0=tic0,
-                progress_bar=self.progress_bar, clear_cache=clear_cache,
-                verbose=verbose
-            )
-            ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
-                save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
-
-        except Exception as e:
-            if isinstance(e, torch.cuda.OutOfMemoryError):
-                logger.exception('Out of memory error, printing performance...')
-                log_performance(logger, level='info')
-                log_cuda_details(logger)
-            # This makes sure the full traceback is written to log file.
-            logger.exception('Encountered error in `run_kilosort`:')
-            # Annoyingly, this will print the error message twice for console
-            # but I haven't found a good way around that.
-            raise
-
-        finally:
-            close_logger()
-
-        self.ops = ops
-        self.st = st[kept_spikes]
-        self.clu = clu[kept_spikes]
-        self.tF = tF[kept_spikes]
-        self.is_refractory = is_ref
-        self.plotDataReady.emit('probe')
+
+        # NOTE: All but `gui_sorter` are positional args,
+        # don't move these around.
+        _ = _sort(
+            settings['filename'], results_dir, probe, settings,
+            settings['data_dtype'], self.device, settings['do_CAR'],
+            settings['clear_cache'], settings['invert_sign'],
+            settings['save_preprocessed_copy'], settings['verbose_log'],
+            False, self.file_object, self.progress_bar, gui_sorter=self
+        )
+        # Hard-coded `False` is for "save_extra_vars", which isn't an option
+        # in the GUI right now (and isn't likely to be added).
+
+        # try:
+        #     logger.info(f"Kilosort version {kilosort.__version__}")
+        #     logger.info(f"Sorting {self.data_path}")
+        #     clear_cache = settings['clear_cache']
+        #     if clear_cache:
+        #         logger.info('clear_cache=True')
+        #     logger.info('-'*40)
+
+        #     tic0 = time.time()
+
+        #     if probe['chanMap'].max() >= settings['n_chan_bin']:
+        #         raise ValueError(
+        #             f'Largest value of chanMap exceeds channel count of data, '
+        #             'make sure chanMap is 0-indexed.'
+        #         )
+
+        #     if settings['nt0min'] is None:
+        #         settings['nt0min'] = int(20 * settings['nt']/61)
+        #     data_dtype = settings['data_dtype']
+        #     device = self.device
+        #     save_preprocessed_copy = settings['save_preprocessed_copy']
+        #     do_CAR = settings['do_CAR']
+        #     invert_sign = settings['invert_sign']
+        #     if not do_CAR:
+        #         logger.info("Skipping common average reference.")
+
+        #     ops = initialize_ops(settings, probe, data_dtype, do_CAR,
+        #                          invert_sign, device, save_preprocessed_copy)
+
+        #     # Pretty-print ops and probe for log
+        #     logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
+        #     logger.debug(f"Probe dictionary:\n\n{probe_as_string(ops['probe'])}\n")
+
+        #     # TODO: add support for file object through data conversion
+        #     # Set preprocessing and drift correction parameters
+        #     ops = compute_preprocessing(ops, self.device, tic0=tic0,
+        #                                 file_object=self.file_object)
+        #     np.random.seed(1)
+        #     torch.cuda.manual_seed_all(1)
+        #     torch.random.manual_seed(1)
+        #     ops, bfile, st0 = compute_drift_correction(
+        #         ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
+        #         file_object=self.file_object, clear_cache=clear_cache
+        #     )
+
+        #     # Check scale of data for log file
+        #     b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
+        #     logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
+
+        #     if save_preprocessed_copy:
+        #         save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
+
+        #     # Will be None if nblocks = 0 (no drift correction)
+        #     if st0 is not None:
+        #         self.dshift = ops['dshift']
+        #         self.st0 = st0
+        #         self.plotDataReady.emit('drift')
+
+        #     # Sort spikes and save results
+        #     st, tF, Wall0, clu0 = detect_spikes(
+        #         ops, self.device, bfile, tic0=tic0,
+        #         progress_bar=self.progress_bar, clear_cache=clear_cache,
+        #         verbose=verbose
+        #     )
+
+        #     self.Wall0 = Wall0
+        #     self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
+        #     self.clu0 = clu0
+        #     self.plotDataReady.emit('diagnostics')
+
+        #     clu, Wall, _ = cluster_spikes(
+        #         st, tF, ops, self.device, bfile, tic0=tic0,
+        #         progress_bar=self.progress_bar, clear_cache=clear_cache,
+        #         verbose=verbose
+        #     )
+        #     ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
+        #         save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
+
+        # except Exception as e:
+        #     if isinstance(e, torch.cuda.OutOfMemoryError):
+        #         logger.exception('Out of memory error, printing performance...')
+        #         log_performance(logger, level='info')
+        #         log_cuda_details(logger)
+        #     # This makes sure the full traceback is written to log file.
+        #     logger.exception('Encountered error in `run_kilosort`:')
+        #     # Annoyingly, this will print the error message twice for console
+        #     # but I haven't found a good way around that.
+        #     raise
+
+        # finally:
+        #     close_logger()
+
+        # self.ops = ops
+        # self.st = st[kept_spikes]
+        # self.clu = clu[kept_spikes]
+        # self.tF = tF[kept_spikes]
+        # self.is_refractory = is_ref
+        # self.plotDataReady.emit('probe')

         self.finishedSpikesort.emit(self.context)
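
The NOTE in the new `run()` body exists because `_sort` takes fourteen positional arguments, so a reordered call would silently bind values to the wrong parameters. One way to guard against that is to pass the trailing options by keyword; the sketch below is illustrative, not part of the commit, and assumes `_sort`'s parameters keep the names shown in kilosort/run_kilosort.py (they are not positional-only, so Python accepts them as keywords):

    # Sketch: keyword-based equivalent of the positional call above.
    # Parameter names are taken from the _sort signature in the next file.
    _ = _sort(
        settings['filename'], results_dir, probe, settings,
        data_dtype=settings['data_dtype'], device=self.device,
        do_CAR=settings['do_CAR'], clear_cache=settings['clear_cache'],
        invert_sign=settings['invert_sign'],
        save_preprocessed_copy=settings['save_preprocessed_copy'],
        verbose_log=settings['verbose_log'],
        save_extra_vars=False,  # not exposed in the GUI
        file_object=self.file_object, progress_bar=self.progress_bar,
        gui_sorter=self,
    )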

kilosort/run_kilosort.py

Lines changed: 59 additions & 3 deletions
@@ -170,9 +170,26 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
     settings = {**DEFAULT_SETTINGS, **settings}
     # NOTE: This modifies settings in-place
     filename, data_dir, results_dir, probe = \
-        set_files(settings, filename, probe, probe_name, data_dir, results_dir, bad_channels)
+        set_files(settings, filename, probe, probe_name, data_dir,
+                  results_dir, bad_channels)
     setup_logger(results_dir, verbose_console=verbose_console)

+    ops, st, clu, tF, Wall, similar_templates, \
+        is_ref, est_contam_rate, kept_spikes = _sort(
+            filename, results_dir, probe, settings, data_dtype, device, do_CAR,
+            clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
+            save_extra_vars, file_object, progress_bar
+        )
+
+    return ops, st, clu, tF, Wall, similar_templates, \
+        is_ref, est_contam_rate, kept_spikes
+
+
+def _sort(filename, results_dir, probe, settings, data_dtype, device, do_CAR,
+          clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
+          save_extra_vars, file_object, progress_bar, gui_sorter=None):
+    """Run sorting pipeline. See `run_kilosort` for documentation."""
+
     try:
         logger.info(f"Kilosort version {kilosort.__version__}")
         logger.info(f"Python version {platform.python_version()}")

@@ -218,7 +235,7 @@

         tic0 = time.time()
         ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
-                            device, save_preprocessed_copy)
+                             device, save_preprocessed_copy)

         # Pretty-print ops and probe for log
         logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")

@@ -242,14 +259,37 @@
         b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
         logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

+        # Save preprocessing steps
         if save_preprocessed_copy:
             io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

+        # Generate drift plots
+        # st0 will be None if nblocks = 0 (no drift correction)
+        if st0 is not None:
+            if gui_sorter is not None:
+                gui_sorter.dshift = ops['dshift']
+                gui_sorter.st0 = st0
+                gui_sorter.plotDataReady.emit('drift')
+            else:
+                # TODO: save non-GUI version of plot to results.
+                pass
+
         # Sort spikes and save results
-        st,tF, _, _ = detect_spikes(
+        st,tF, Wall0, clu0 = detect_spikes(
             ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
             clear_cache=clear_cache, verbose=verbose_log
         )
+
+        # Generate diagnostic plots
+        if gui_sorter is not None:
+            gui_sorter.Wall0 = Wall0
+            gui_sorter.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
+            gui_sorter.clu0 = clu0
+            gui_sorter.plotDataReady.emit('diagnostics')
+        else:
+            # TODO: save non-GUI version of plot to results.
+            pass
+
         clu, Wall, st, tF = cluster_spikes(
             st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
             clear_cache=clear_cache, verbose=verbose_log,

@@ -260,6 +300,21 @@
             save_extra_vars=save_extra_vars,
             save_preprocessed_copy=save_preprocessed_copy
         )
+
+        # Generate spike positions plot
+        if gui_sorter is not None:
+            # TODO: re-use spike positions saved by `save_sorting` instead of
+            # computing them again in `kilosort.gui.sanity_plots`.
+            gui_sorter.ops = ops
+            gui_sorter.st = st[kept_spikes]
+            gui_sorter.clu = clu[kept_spikes]
+            gui_sorter.tF = tF[kept_spikes]
+            gui_sorter.is_refractory = is_ref
+            gui_sorter.plotDataReady.emit('probe')
+        else:
+            # TODO: save non-GUI version of plot to results.
+            pass
+
     except Exception as e:
         if isinstance(e, torch.cuda.OutOfMemoryError):
             logger.exception('Out of memory error, printing performance...')

@@ -275,6 +330,7 @@
     finally:
         close_logger()

+
     return ops, st, clu, tF, Wall, similar_templates, \
         is_ref, est_contam_rate, kept_spikes
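
Because `_sort` only ever assigns attributes on `gui_sorter` and calls `gui_sorter.plotDataReady.emit(...)` with 'drift', 'diagnostics', or 'probe', the hook is duck-typed: any object with those two capabilities works, not just the Qt worker in kilosort/gui/sorter.py. A minimal hypothetical stand-in, useful for tracing the plot stages without a GUI:

    # Hypothetical stand-in for the gui_sorter argument; not part of the
    # commit. It satisfies the only contract _sort relies on: plain
    # attribute assignment plus a plotDataReady object with an emit() method.
    class PlotSignal:
        def emit(self, plot_type):
            # plot_type is one of 'drift', 'diagnostics', 'probe'
            print(f'plot data ready: {plot_type}')

    class HeadlessSorter:
        def __init__(self):
            self.plotDataReady = PlotSignal()

    # Passing gui_sorter=HeadlessSorter() to _sort would then print each
    # plotting stage instead of emitting a Qt signal.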
