diff --git a/kilosort/gui/sorter.py b/kilosort/gui/sorter.py index d504066f..1e74aaa8 100644 --- a/kilosort/gui/sorter.py +++ b/kilosort/gui/sorter.py @@ -8,8 +8,9 @@ import kilosort from kilosort.run_kilosort import ( - setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction, - detect_spikes, cluster_spikes, save_sorting, close_logger + # setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction, + # detect_spikes, cluster_spikes, save_sorting, close_logger + setup_logger, _sort ) from kilosort.io import save_preprocessing from kilosort.utils import ( @@ -46,105 +47,17 @@ def run(self): results_dir.mkdir(parents=True) setup_logger(results_dir) - verbose = settings['verbose_log'] - try: - logger.info(f"Kilosort version {kilosort.__version__}") - logger.info(f"Sorting {self.data_path}") - clear_cache = settings['clear_cache'] - if clear_cache: - logger.info('clear_cache=True') - logger.info('-'*40) - - tic0 = time.time() - - if probe['chanMap'].max() >= settings['n_chan_bin']: - raise ValueError( - f'Largest value of chanMap exceeds channel count of data, ' - 'make sure chanMap is 0-indexed.' - ) - - if settings['nt0min'] is None: - settings['nt0min'] = int(20 * settings['nt']/61) - data_dtype = settings['data_dtype'] - device = self.device - save_preprocessed_copy = settings['save_preprocessed_copy'] - do_CAR = settings['do_CAR'] - invert_sign = settings['invert_sign'] - if not do_CAR: - logger.info("Skipping common average reference.") - - ops = initialize_ops(settings, probe, data_dtype, do_CAR, - invert_sign, device, save_preprocessed_copy) - - # Pretty-print ops and probe for log - logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n") - logger.debug(f"Probe dictionary:\n\n{probe_as_string(ops['probe'])}\n") - - # TODO: add support for file object through data conversion - # Set preprocessing and drift correction parameters - ops = compute_preprocessing(ops, self.device, tic0=tic0, - file_object=self.file_object) - np.random.seed(1) - torch.cuda.manual_seed_all(1) - torch.random.manual_seed(1) - ops, bfile, st0 = compute_drift_correction( - ops, self.device, tic0=tic0, progress_bar=self.progress_bar, - file_object=self.file_object, clear_cache=clear_cache - ) - - # Check scale of data for log file - b1 = bfile.padded_batch_to_torch(0).cpu().numpy() - logger.debug(f"First batch min, max: {b1.min(), b1.max()}") - - if save_preprocessed_copy: - save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile) - - # Will be None if nblocks = 0 (no drift correction) - if st0 is not None: - self.dshift = ops['dshift'] - self.st0 = st0 - self.plotDataReady.emit('drift') - - # Sort spikes and save results - st, tF, Wall0, clu0 = detect_spikes( - ops, self.device, bfile, tic0=tic0, - progress_bar=self.progress_bar, clear_cache=clear_cache, - verbose=verbose - ) - - self.Wall0 = Wall0 - self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy() - self.clu0 = clu0 - self.plotDataReady.emit('diagnostics') - - clu, Wall = cluster_spikes( - st, tF, ops, self.device, bfile, tic0=tic0, - progress_bar=self.progress_bar, clear_cache=clear_cache, - verbose=verbose - ) - ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ - save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0) - - except Exception as e: - if isinstance(e, torch.cuda.OutOfMemoryError): - logger.exception('Out of memory error, printing performance...') - log_performance(logger, level='info') - log_cuda_details(logger) - # This makes sure the full traceback is written to log file. - logger.exception('Encountered error in `run_kilosort`:') - # Annoyingly, this will print the error message twice for console - # but I haven't found a good way around that. - raise - - finally: - close_logger() - - self.ops = ops - self.st = st[kept_spikes] - self.clu = clu[kept_spikes] - self.tF = tF[kept_spikes] - self.is_refractory = is_ref - self.plotDataReady.emit('probe') + # NOTE: All but `gui_sorter` are positional args, + # don't move these around. + _ = _sort( + settings['filename'], results_dir, probe, settings, + settings['data_dtype'], self.device, settings['do_CAR'], + settings['clear_cache'], settings['invert_sign'], + settings['save_preprocessed_copy'], settings['verbose_log'], + False, self.file_object, self.progress_bar, gui_sorter=self + ) + # Hard-coded `False` is for "save_extra_vars", which isn't an option + # in the GUI right now (and isn't likely to be added). self.finishedSpikesort.emit(self.context) diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index 1715762a..bdc5833c 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -170,9 +170,26 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, settings = {**DEFAULT_SETTINGS, **settings} # NOTE: This modifies settings in-place filename, data_dir, results_dir, probe = \ - set_files(settings, filename, probe, probe_name, data_dir, results_dir, bad_channels) + set_files(settings, filename, probe, probe_name, data_dir, + results_dir, bad_channels) setup_logger(results_dir, verbose_console=verbose_console) + ops, st, clu, tF, Wall, similar_templates, \ + is_ref, est_contam_rate, kept_spikes = _sort( + filename, results_dir, probe, settings, data_dtype, device, do_CAR, + clear_cache, invert_sign, save_preprocessed_copy, verbose_log, + save_extra_vars, file_object, progress_bar + ) + + return ops, st, clu, tF, Wall, similar_templates, \ + is_ref, est_contam_rate, kept_spikes + + +def _sort(filename, results_dir, probe, settings, data_dtype, device, do_CAR, + clear_cache, invert_sign, save_preprocessed_copy, verbose_log, + save_extra_vars, file_object, progress_bar, gui_sorter=None): + """Run sorting pipeline. See `run_kilosort` for documentation.""" + try: logger.info(f"Kilosort version {kilosort.__version__}") logger.info(f"Python version {platform.python_version()}") @@ -218,7 +235,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, tic0 = time.time() ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, - device, save_preprocessed_copy) + device, save_preprocessed_copy) # Pretty-print ops and probe for log logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n") @@ -242,17 +259,40 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, b1 = bfile.padded_batch_to_torch(0).cpu().numpy() logger.debug(f"First batch min, max: {b1.min(), b1.max()}") + # Save preprocessing steps if save_preprocessed_copy: io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile) + # Generate drift plots + # st0 will be None if nblocks = 0 (no drift correction) + if st0 is not None: + if gui_sorter is not None: + gui_sorter.dshift = ops['dshift'] + gui_sorter.st0 = st0 + gui_sorter.plotDataReady.emit('drift') + else: + # TODO: save non-GUI version of plot to results. + pass + # Sort spikes and save results - st,tF, _, _ = detect_spikes( + st,tF, Wall0, clu0 = detect_spikes( ops, device, bfile, tic0=tic0, progress_bar=progress_bar, clear_cache=clear_cache, verbose=verbose_log ) - clu, Wall = cluster_spikes( + + # Generate diagnosic plots + if gui_sorter is not None: + gui_sorter.Wall0 = Wall0 + gui_sorter.wPCA = torch.clone(ops['wPCA'].cpu()).numpy() + gui_sorter.clu0 = clu0 + gui_sorter.plotDataReady.emit('diagnostics') + else: + # TODO: save non-GUI version of plot to results. + pass + + clu, Wall, st, tF = cluster_spikes( st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar, - clear_cache=clear_cache, verbose=verbose_log + clear_cache=clear_cache, verbose=verbose_log, ) ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ save_sorting( @@ -260,6 +300,21 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, save_extra_vars=save_extra_vars, save_preprocessed_copy=save_preprocessed_copy ) + + # Generate spike positions plot + if gui_sorter is not None: + # TODO: re-use spike positions saved by `save_sorting` instead of + # computing them again in `kilosort.gui.sanity_plots`. + gui_sorter.ops = ops + gui_sorter.st = st[kept_spikes] + gui_sorter.clu = clu[kept_spikes] + gui_sorter.tF = tF[kept_spikes] + gui_sorter.is_refractory = is_ref + gui_sorter.plotDataReady.emit('probe') + else: + # TODO: save non-GUI version of plot to results. + pass + except Exception as e: if isinstance(e, torch.cuda.OutOfMemoryError): logger.exception('Out of memory error, printing performance...') @@ -275,6 +330,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, finally: close_logger() + return ops, st, clu, tF, Wall, similar_templates, \ is_ref, est_contam_rate, kept_spikes @@ -676,7 +732,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None, ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar, clear_cache=clear_cache, verbose=verbose ) - Wall3 = template_matching.postprocess_templates(Wall, ops, clu, st0, device=device) + Wall3 = template_matching.postprocess_templates( + Wall, ops, clu, st0, tF, device=device + ) logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' + f'total {time.time()-tic0 : .2f}s') logger.debug(f'clu shape: {clu.shape}') @@ -686,8 +744,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None, logger.info(' ') logger.info('Extracting spikes using cluster waveforms') logger.info('-'*40) - st, tF, ops = template_matching.extract(ops, bfile, Wall3, device=device, - progress_bar=progress_bar) + st, tF, ops = template_matching.extract( + ops, bfile, Wall3, device=device, progress_bar=progress_bar + ) logger.info(f'{len(st)} spikes extracted in {time.time()-tic : .2f}s; ' + f'total {time.time()-tic0 : .2f}s') logger.debug(f'st shape: {st.shape}') @@ -756,8 +815,9 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None, logger.info(' ') logger.info('Merging clusters') logger.info('-'*40) - Wall, clu, is_ref = template_matching.merging_function(ops, Wall, clu, st[:,0], - device=device) + Wall, clu, is_ref, st, tF = template_matching.merging_function( + ops, Wall, clu, st, tF, device=device, check_dt=True + ) clu = clu.astype('int32') logger.info(f'{clu.max()+1} units found, in {time.time()-tic : .2f}s; ' + f'total {time.time()-tic0 : .2f}s') @@ -767,7 +827,7 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None, log_performance(logger, 'info', 'Resource usage after clustering') log_cuda_details(logger) - return clu, Wall + return clu, Wall, st, tF def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan, diff --git a/kilosort/template_matching.py b/kilosort/template_matching.py index b4fbd15e..dca96c59 100644 --- a/kilosort/template_matching.py +++ b/kilosort/template_matching.py @@ -140,11 +140,13 @@ def align_U(U, ops, device=torch.device('cuda')): return Unew, imax -def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')): +def postprocess_templates(Wall, ops, clu, st, tF, device=torch.device('cuda')): Wall2, _ = align_U(Wall, ops, device=device) #Wall3, _= remove_duplicates(ops, Wall2) - Wall3, _, _ = merging_function(ops, Wall2.transpose(1,2), clu, st[:,0], - 0.9, 'mu', device=device) + Wall3, _, _, _, _ = merging_function( + ops, Wall2.transpose(1,2), clu, st, tF, + 0.9, 'mu', check_dt=False, device=device + ) Wall3 = Wall3.transpose(1,2).to(device) return Wall3 @@ -241,7 +243,8 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')): return st, amps, th_amps, Xres -def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.device('cuda')): +def merging_function(ops, Wall, clu, st, tF, r_thresh=0.5, mode='ccg', check_dt=True, + device=torch.device('cuda')): clu2 = clu.copy() clu_unq, ns = np.unique(clu2, return_counts = True) @@ -256,7 +259,7 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch. acg_threshold = ops['settings']['acg_threshold'] ccg_threshold = ops['settings']['ccg_threshold'] if mode == 'ccg': - is_ref, est_contam_rate = CCG.refract(clu, st/ops['fs'], + is_ref, est_contam_rate = CCG.refract(clu, st[:,0]/ops['fs'], acg_threshold=acg_threshold, ccg_threshold=ccg_threshold) @@ -287,13 +290,13 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch. UtU = torch.einsum('lk, jlm -> jkm', Wnorm[kk], Wnorm) ctc = torch.einsum('jkm, kml -> jl', UtU, WtW) - cmax = ctc.max(1)[0] + cmax, imax = ctc.max(1) cmax[kk] = 0 jsort = np.argsort(cmax.cpu().numpy())[::-1] if mode == 'ccg': - st0 = st[clu2==kk] / ops['fs'] + st0 = st[:,0][clu2==kk] / ops['fs'] is_ccg = 0 for j in range(NN): @@ -302,7 +305,7 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch. break # compare with CCG if mode == 'ccg': - st1 = st[clu2==jj] / ops['fs'] + st1 = st[:,0][clu2==jj] / ops['fs'] _, is_ccg, _ = CCG.check_CCG(st0, st1, acg_threshold=acg_threshold, ccg_threshold=ccg_threshold) else: @@ -311,9 +314,17 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch. if is_ccg: is_merged[jj] = 1 + dt = (imax[kk] -imax[jj]).item() + if dt != 0 and check_dt: + # Get spike indices for cluster jj + idx = (clu2 == jj) + # Update tF and Wall with shifted features + tF, Wall = roll_features(W, tF, Ww, idx, jj, dt) + # Shift spike times + st[idx,0] -= dt + Ww[kk] = ns[kk]/(ns[kk]+ns[jj]) * Ww[kk] + ns[jj]/(ns[kk]+ns[jj]) * Ww[jj] Ww[jj] = 0 - ns[kk] += ns[jj] ns[jj] = 0 clu2[clu2==jj] = kk @@ -337,4 +348,31 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch. else: is_ref = None - return Ww.cpu(), clu2, is_ref + sorted_idx = np.argsort(st[:,0]) + st = np.take_along_axis(st, sorted_idx[..., np.newaxis], axis=0) + clu2 = clu2[sorted_idx] + tensor_idx = torch.from_numpy(sorted_idx) + tF = tF[tensor_idx] + + return Ww.cpu(), clu2, is_ref, st, tF + + +def roll_features(wPCA, tF, Wall, spike_idx, clust_idx, dt): + W = wPCA.cpu() + # Project from PC space back to sample time, shift by dt + feats = torch.roll(tF[spike_idx] @ W, shifts=dt, dims=2) + temps = torch.roll(Wall[clust_idx:clust_idx+1] @ wPCA, shifts=dt, dims=2) + + # For values that "rolled over the edge," set equal to next closest bin + if dt > 0: + feats[:,:,:dt] = feats[:,:,dt].unsqueeze(-1) + temps[:,:,:dt] = temps[:,:,dt].unsqueeze(-1) + elif dt < 0: + feats[:,:,dt:] = feats[:,:,dt-1].unsqueeze(-1) + temps[:,:,dt:] = temps[:,:,dt-1].unsqueeze(-1) + + # Project back to PC space and update tF + tF[spike_idx] = feats @ W.T + Wall[clust_idx] = temps @ wPCA.T + + return tF, Wall