115 changes: 14 additions & 101 deletions kilosort/gui/sorter.py
@@ -8,8 +8,9 @@

import kilosort
from kilosort.run_kilosort import (
setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
detect_spikes, cluster_spikes, save_sorting, close_logger
# setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
# detect_spikes, cluster_spikes, save_sorting, close_logger
setup_logger, _sort
)
from kilosort.io import save_preprocessing
from kilosort.utils import (
@@ -46,105 +47,17 @@ def run(self):
results_dir.mkdir(parents=True)

setup_logger(results_dir)
verbose = settings['verbose_log']

try:
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {self.data_path}")
clear_cache = settings['clear_cache']
if clear_cache:
logger.info('clear_cache=True')
logger.info('-'*40)

tic0 = time.time()

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
f'Largest value of chanMap exceeds channel count of data, '
'make sure chanMap is 0-indexed.'
)

if settings['nt0min'] is None:
settings['nt0min'] = int(20 * settings['nt']/61)
data_dtype = settings['data_dtype']
device = self.device
save_preprocessed_copy = settings['save_preprocessed_copy']
do_CAR = settings['do_CAR']
invert_sign = settings['invert_sign']
if not do_CAR:
logger.info("Skipping common average reference.")

ops = initialize_ops(settings, probe, data_dtype, do_CAR,
invert_sign, device, save_preprocessed_copy)

# Pretty-print ops and probe for log
logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
logger.debug(f"Probe dictionary:\n\n{probe_as_string(ops['probe'])}\n")

# TODO: add support for file object through data conversion
# Set preprocessing and drift correction parameters
ops = compute_preprocessing(ops, self.device, tic0=tic0,
file_object=self.file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
file_object=self.file_object, clear_cache=clear_cache
)

# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Will be None if nblocks = 0 (no drift correction)
if st0 is not None:
self.dshift = ops['dshift']
self.st0 = st0
self.plotDataReady.emit('drift')

# Sort spikes and save results
st, tF, Wall0, clu0 = detect_spikes(
ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar, clear_cache=clear_cache,
verbose=verbose
)

self.Wall0 = Wall0
self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
self.clu0 = clu0
self.plotDataReady.emit('diagnostics')

clu, Wall = cluster_spikes(
st, tF, ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar, clear_cache=clear_cache,
verbose=verbose
)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)

except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
logger.exception('Out of memory error, printing performance...')
log_performance(logger, level='info')
log_cuda_details(logger)
# This makes sure the full traceback is written to log file.
logger.exception('Encountered error in `run_kilosort`:')
# Annoyingly, this will print the error message twice for console
# but I haven't found a good way around that.
raise

finally:
close_logger()

self.ops = ops
self.st = st[kept_spikes]
self.clu = clu[kept_spikes]
self.tF = tF[kept_spikes]
self.is_refractory = is_ref
self.plotDataReady.emit('probe')
# NOTE: All but `gui_sorter` are positional args,
# don't move these around.
_ = _sort(
settings['filename'], results_dir, probe, settings,
settings['data_dtype'], self.device, settings['do_CAR'],
settings['clear_cache'], settings['invert_sign'],
settings['save_preprocessed_copy'], settings['verbose_log'],
False, self.file_object, self.progress_bar, gui_sorter=self
)
# Hard-coded `False` is for "save_extra_vars", which isn't an option
# in the GUI right now (and isn't likely to be added).

self.finishedSpikesort.emit(self.context)
82 changes: 71 additions & 11 deletions kilosort/run_kilosort.py
@@ -170,9 +170,26 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
settings = {**DEFAULT_SETTINGS, **settings}
# NOTE: This modifies settings in-place
filename, data_dir, results_dir, probe = \
set_files(settings, filename, probe, probe_name, data_dir, results_dir, bad_channels)
set_files(settings, filename, probe, probe_name, data_dir,
results_dir, bad_channels)
setup_logger(results_dir, verbose_console=verbose_console)

ops, st, clu, tF, Wall, similar_templates, \
is_ref, est_contam_rate, kept_spikes = _sort(
filename, results_dir, probe, settings, data_dtype, device, do_CAR,
clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
save_extra_vars, file_object, progress_bar
)

return ops, st, clu, tF, Wall, similar_templates, \
is_ref, est_contam_rate, kept_spikes


def _sort(filename, results_dir, probe, settings, data_dtype, device, do_CAR,
clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
save_extra_vars, file_object, progress_bar, gui_sorter=None):
"""Run sorting pipeline. See `run_kilosort` for documentation."""

try:
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Python version {platform.python_version()}")
@@ -218,7 +235,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,

tic0 = time.time()
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
device, save_preprocessed_copy)
device, save_preprocessed_copy)

# Pretty-print ops and probe for log
logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
@@ -242,24 +259,62 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

# Save preprocessing steps
if save_preprocessed_copy:
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Generate drift plots
# st0 will be None if nblocks = 0 (no drift correction)
if st0 is not None:
if gui_sorter is not None:
gui_sorter.dshift = ops['dshift']
gui_sorter.st0 = st0
gui_sorter.plotDataReady.emit('drift')
else:
# TODO: save non-GUI version of plot to results.
pass

# Sort spikes and save results
st,tF, _, _ = detect_spikes(
st,tF, Wall0, clu0 = detect_spikes(
ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
clear_cache=clear_cache, verbose=verbose_log
)
clu, Wall = cluster_spikes(

# Generate diagnostic plots
if gui_sorter is not None:
gui_sorter.Wall0 = Wall0
gui_sorter.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
gui_sorter.clu0 = clu0
gui_sorter.plotDataReady.emit('diagnostics')
else:
# TODO: save non-GUI version of plot to results.
pass

clu, Wall, st, tF = cluster_spikes(
st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
clear_cache=clear_cache, verbose=verbose_log
clear_cache=clear_cache, verbose=verbose_log,
)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(
ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy
)

# Generate spike positions plot
if gui_sorter is not None:
# TODO: re-use spike positions saved by `save_sorting` instead of
# computing them again in `kilosort.gui.sanity_plots`.
gui_sorter.ops = ops
gui_sorter.st = st[kept_spikes]
gui_sorter.clu = clu[kept_spikes]
gui_sorter.tF = tF[kept_spikes]
gui_sorter.is_refractory = is_ref
gui_sorter.plotDataReady.emit('probe')
else:
# TODO: save non-GUI version of plot to results.
pass

except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
logger.exception('Out of memory error, printing performance...')
@@ -275,6 +330,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
finally:
close_logger()


return ops, st, clu, tF, Wall, similar_templates, \
is_ref, est_contam_rate, kept_spikes

@@ -676,7 +732,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar,
clear_cache=clear_cache, verbose=verbose
)
Wall3 = template_matching.postprocess_templates(Wall, ops, clu, st0, device=device)
Wall3 = template_matching.postprocess_templates(
Wall, ops, clu, st0, tF, device=device
)
logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
logger.debug(f'clu shape: {clu.shape}')
@@ -686,8 +744,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
logger.info(' ')
logger.info('Extracting spikes using cluster waveforms')
logger.info('-'*40)
st, tF, ops = template_matching.extract(ops, bfile, Wall3, device=device,
progress_bar=progress_bar)
st, tF, ops = template_matching.extract(
ops, bfile, Wall3, device=device, progress_bar=progress_bar
)
logger.info(f'{len(st)} spikes extracted in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
logger.debug(f'st shape: {st.shape}')
@@ -756,8 +815,9 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
logger.info(' ')
logger.info('Merging clusters')
logger.info('-'*40)
Wall, clu, is_ref = template_matching.merging_function(ops, Wall, clu, st[:,0],
device=device)
Wall, clu, is_ref, st, tF = template_matching.merging_function(
ops, Wall, clu, st, tF, device=device, check_dt=True
)
clu = clu.astype('int32')
logger.info(f'{clu.max()+1} units found, in {time.time()-tic : .2f}s; ' +
f'total {time.time()-tic0 : .2f}s')
@@ -767,7 +827,7 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
log_performance(logger, 'info', 'Resource usage after clustering')
log_cuda_details(logger)

return clu, Wall
return clu, Wall, st, tF


def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
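Worth noting for downstream callers (editorial sketch, not part of the diff): because `merging_function` can now shift spike times by `dt` and re-sort them, `cluster_spikes` returns updated `st` and `tF` alongside `clu` and `Wall`, and those returned arrays must replace the pre-clustering ones before `save_sorting`:

```python
# Sketch of the updated call pattern, using the signatures shown in the diff above.
clu, Wall, st, tF = cluster_spikes(
    st, tF, ops, device, bfile, tic0=tic0,
    progress_bar=progress_bar, clear_cache=clear_cache, verbose=verbose_log,
)
# st/tF returned here supersede the pre-merge arrays.
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = save_sorting(
    ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
    save_extra_vars=save_extra_vars,
    save_preprocessed_copy=save_preprocessed_copy,
)
```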
58 changes: 48 additions & 10 deletions kilosort/template_matching.py
@@ -140,11 +140,13 @@ def align_U(U, ops, device=torch.device('cuda')):
return Unew, imax


def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')):
def postprocess_templates(Wall, ops, clu, st, tF, device=torch.device('cuda')):
Wall2, _ = align_U(Wall, ops, device=device)
#Wall3, _= remove_duplicates(ops, Wall2)
Wall3, _, _ = merging_function(ops, Wall2.transpose(1,2), clu, st[:,0],
0.9, 'mu', device=device)
Wall3, _, _, _, _ = merging_function(
ops, Wall2.transpose(1,2), clu, st, tF,
0.9, 'mu', check_dt=False, device=device
)
Wall3 = Wall3.transpose(1,2).to(device)
return Wall3

@@ -241,7 +243,8 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
return st, amps, th_amps, Xres


def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.device('cuda')):
def merging_function(ops, Wall, clu, st, tF, r_thresh=0.5, mode='ccg', check_dt=True,
device=torch.device('cuda')):
clu2 = clu.copy()
clu_unq, ns = np.unique(clu2, return_counts = True)

@@ -256,7 +259,7 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
acg_threshold = ops['settings']['acg_threshold']
ccg_threshold = ops['settings']['ccg_threshold']
if mode == 'ccg':
is_ref, est_contam_rate = CCG.refract(clu, st/ops['fs'],
is_ref, est_contam_rate = CCG.refract(clu, st[:,0]/ops['fs'],
acg_threshold=acg_threshold,
ccg_threshold=ccg_threshold)

@@ -287,13 +290,13 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
UtU = torch.einsum('lk, jlm -> jkm', Wnorm[kk], Wnorm)
ctc = torch.einsum('jkm, kml -> jl', UtU, WtW)

cmax = ctc.max(1)[0]
cmax, imax = ctc.max(1)
cmax[kk] = 0

jsort = np.argsort(cmax.cpu().numpy())[::-1]

if mode == 'ccg':
st0 = st[clu2==kk] / ops['fs']
st0 = st[:,0][clu2==kk] / ops['fs']

is_ccg = 0
for j in range(NN):
@@ -302,7 +305,7 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
break
# compare with CCG
if mode == 'ccg':
st1 = st[clu2==jj] / ops['fs']
st1 = st[:,0][clu2==jj] / ops['fs']
_, is_ccg, _ = CCG.check_CCG(st0, st1, acg_threshold=acg_threshold,
ccg_threshold=ccg_threshold)
else:
@@ -311,9 +314,17 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.

if is_ccg:
is_merged[jj] = 1
dt = (imax[kk] -imax[jj]).item()
if dt != 0 and check_dt:
# Get spike indices for cluster jj
idx = (clu2 == jj)
# Update tF and Wall with shifted features
tF, Wall = roll_features(W, tF, Ww, idx, jj, dt)
# Shift spike times
st[idx,0] -= dt

Ww[kk] = ns[kk]/(ns[kk]+ns[jj]) * Ww[kk] + ns[jj]/(ns[kk]+ns[jj]) * Ww[jj]
Ww[jj] = 0

ns[kk] += ns[jj]
ns[jj] = 0
clu2[clu2==jj] = kk
@@ -337,4 +348,31 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
else:
is_ref = None

return Ww.cpu(), clu2, is_ref
sorted_idx = np.argsort(st[:,0])
st = np.take_along_axis(st, sorted_idx[..., np.newaxis], axis=0)
clu2 = clu2[sorted_idx]
tensor_idx = torch.from_numpy(sorted_idx)
tF = tF[tensor_idx]

return Ww.cpu(), clu2, is_ref, st, tF


def roll_features(wPCA, tF, Wall, spike_idx, clust_idx, dt):
W = wPCA.cpu()
# Project from PC space back to sample time, shift by dt
feats = torch.roll(tF[spike_idx] @ W, shifts=dt, dims=2)
temps = torch.roll(Wall[clust_idx:clust_idx+1] @ wPCA, shifts=dt, dims=2)

# For values that "rolled over the edge," set equal to next closest bin
if dt > 0:
feats[:,:,:dt] = feats[:,:,dt].unsqueeze(-1)
temps[:,:,:dt] = temps[:,:,dt].unsqueeze(-1)
elif dt < 0:
feats[:,:,dt:] = feats[:,:,dt-1].unsqueeze(-1)
temps[:,:,dt:] = temps[:,:,dt-1].unsqueeze(-1)

# Project back to PC space and update tF
tF[spike_idx] = feats @ W.T
Wall[clust_idx] = temps @ wPCA.T

return tF, Wall
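
As a quick illustration of the edge handling in `roll_features` (illustrative values only, not from the PR): `torch.roll` shifts along the time axis, and the bins that wrap around the boundary are overwritten with the nearest valid bin, so no samples wrap from one end of the waveform to the other.

```python
import torch

# Minimal demonstration of the roll-and-clamp step used in roll_features.
x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6)  # one waveform, 6 time bins
dt = 2

rolled = torch.roll(x, shifts=dt, dims=2)            # [[[4., 5., 0., 1., 2., 3.]]]
rolled[:, :, :dt] = rolled[:, :, dt].unsqueeze(-1)   # clamp wrapped bins to first valid bin
print(rolled)                                        # [[[0., 0., 0., 1., 2., 3.]]]
```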