diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 74ef52e258..d4ed5e8e37 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -1105,6 +1105,9 @@ def _compute_metrics(
         -------
         metrics : pd.DataFrame
             DataFrame containing the computed metrics for each unit.
+        run_times : dict
+            Dictionary containing the computation time for each metric.
+
         """
         import pandas as pd
 
@@ -1121,11 +1124,17 @@ def _compute_metrics(
 
         metrics = pd.DataFrame(index=unit_ids, columns=list(column_names_dtypes.keys()))
 
+        run_times = {}
+
         for metric_name in metric_names:
             metric = [m for m in self.metric_list if m.metric_name == metric_name][0]
             column_names = list(metric.metric_columns.keys())
+            import time
+
+            t_start = time.perf_counter()
             try:
                 metric_params = self.params["metric_params"].get(metric_name, {})
+
                 res = metric.compute(
                     sorting_analyzer,
                     unit_ids=unit_ids,
@@ -1139,6 +1148,8 @@ def _compute_metrics(
                     res = {unit_id: np.nan for unit_id in unit_ids}
                 else:
                     res = namedtuple("MetricResult", column_names)(*([np.nan] * len(column_names)))
+            t_end = time.perf_counter()
+            run_times[metric_name] = t_end - t_start
 
             # res is a namedtuple with several dictionary entries (one per column)
             if isinstance(res, dict):
@@ -1150,7 +1161,7 @@ def _compute_metrics(
 
         metrics = self._cast_metrics(metrics)
 
-        return metrics
+        return metrics, run_times
 
     def _run(self, **job_kwargs):
 
@@ -1161,7 +1172,7 @@ def _run(self, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
 
         # compute the metrics which have been specified by the user
-        computed_metrics = self._compute_metrics(
+        computed_metrics, run_times = self._compute_metrics(
             sorting_analyzer=self.sorting_analyzer, unit_ids=None, metric_names=metrics_to_compute, **job_kwargs
         )
 
@@ -1189,6 +1200,7 @@ def _run(self, **job_kwargs):
                     computed_metrics[column_name] = extension.data["metrics"][column_name]
 
         self.data["metrics"] = computed_metrics
+        self.data["runtime_s"] = run_times
 
     def _get_data(self):
         # convert to correct dtype
@@ -1265,7 +1277,7 @@ def _merge_extension_data(
 
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
-        metrics.loc[new_unit_ids, :] = self._compute_metrics(
+        metrics.loc[new_unit_ids, :], _ = self._compute_metrics(
             sorting_analyzer=new_sorting_analyzer, unit_ids=new_unit_ids, metric_names=metric_names, **job_kwargs
         )
         metrics = self._cast_metrics(metrics)
@@ -1309,7 +1321,7 @@ def _split_extension_data(
 
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
-        metrics.loc[new_unit_ids_f, :] = self._compute_metrics(
+        metrics.loc[new_unit_ids_f, :], _ = self._compute_metrics(
             sorting_analyzer=new_sorting_analyzer, unit_ids=new_unit_ids_f, metric_names=metric_names, **job_kwargs
         )
         metrics = self._cast_metrics(metrics)
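For reference, a minimal standalone sketch of the per-metric timing pattern introduced above (time.perf_counter around each metric, results collected in a dict keyed by metric name). The toy metric functions are placeholders, not the SpikeInterface API.

import time
import numpy as np

def _toy_metric_a(rng):
    return float(np.mean(rng.normal(size=100_000)))

def _toy_metric_b(rng):
    return float(np.std(rng.normal(size=100_000)))

metric_funcs = {"metric_a": _toy_metric_a, "metric_b": _toy_metric_b}

rng = np.random.default_rng(0)
results = {}
run_times = {}
for metric_name, func in metric_funcs.items():
    t_start = time.perf_counter()
    try:
        results[metric_name] = func(rng)
    except Exception:
        results[metric_name] = np.nan
    # timing includes the failure path, as in the patch above
    run_times[metric_name] = time.perf_counter() - t_start

print(results, run_times)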
diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py
index c6b07da52e..a4e60b064d 100644
--- a/src/spikeinterface/metrics/quality/misc_metrics.py
+++ b/src/spikeinterface/metrics/quality/misc_metrics.py
@@ -67,7 +67,7 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0
         unit_ids = sorting_analyzer.unit_ids
     num_segs = sorting_analyzer.get_num_segments()
 
-    seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)]
+    # seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)]
    total_length = sorting_analyzer.get_total_samples()
    total_duration = sorting_analyzer.get_total_duration()
    bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency))
@@ -89,13 +89,16 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0
         )
         presence_ratios = {unit_id: np.nan for unit_id in unit_ids}
     else:
+
+        spikes = sorting.to_spike_vector()
+        order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
+        new_spikes = spikes[order]
+
+        # slice the lexsorted spike vector per unit
         for unit_id in unit_ids:
-            spike_train = []
-            for segment_index in range(num_segs):
-                st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
-                st = st + np.sum(seg_lengths[:segment_index])
-                spike_train.append(st)
-            spike_train = np.concatenate(spike_train)
+            unit_index = sorting.id_to_index(unit_id)
+            u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
+            spike_train = new_spikes[u0:u1]["sample_index"]
             unit_fr = spike_train.size / total_duration
             bin_n_spikes_thres = math.floor(unit_fr * bin_duration_s * mean_fr_ratio_thresh)
@@ -248,12 +251,19 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5
 
     isi_violations_count = {}
     isi_violations_ratio = {}
 
-    # all units converted to seconds
+    spikes = sorting.to_spike_vector()
+    order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
+    new_spikes = spikes[order]
+
+    # slice the lexsorted spike vector per unit, then per segment
    for unit_id in unit_ids:
+        unit_index = sorting.id_to_index(unit_id)
+        u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
+        sub_data = new_spikes[u0:u1]
         spike_train_list = []
-
-        for segment_index in range(num_segs):
-            spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
+        for segment_index in range(sorting_analyzer.get_num_segments()):
+            s0, s1 = np.searchsorted(sub_data["segment_index"], [segment_index, segment_index + 1], side="left")
+            spike_train = new_spikes[u0 + s0 : u0 + s1]["sample_index"]
 
             if len(spike_train) > 0:
                 spike_train_list.append(spike_train / fs)
@@ -347,21 +358,22 @@ def compute_refrac_period_violations(
     t_r = int(round(refractory_period_ms * fs * 1e-3))
     nb_rp_violations = np.zeros((num_units), dtype=np.int64)
 
+    all_unit_indices = sorting.ids_to_indices(unit_ids)
+
     for seg_index in range(num_segments):
         spike_times = spikes[seg_index]["sample_index"].astype(np.int64)
         spike_labels = spikes[seg_index]["unit_index"].astype(np.int32)
-        _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r)
+        _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r, all_unit_indices)
 
     T = sorting_analyzer.get_total_samples()
 
     nb_violations = {}
     rp_contamination = {}
 
-    for unit_index, unit_id in enumerate(sorting.unit_ids):
-        if unit_id not in unit_ids:
-            continue
-
-        nb_violations[unit_id] = n_v = nb_rp_violations[unit_index]
+    for unit_id in unit_ids:
+        unit_index = sorting.id_to_index(unit_id)
+        n_v = nb_rp_violations[unit_index]
+        nb_violations[unit_id] = n_v
         N = num_spikes[unit_id]
         if N == 0:
             rp_contamination[unit_id] = np.nan
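As an aside, a self-contained numpy sketch of the slicing pattern these hunks rely on: lexsort the spike vector so that unit_index is the primary key, then take each unit's block with a single np.searchsorted call. The toy spike-vector dtype mirrors the fields used above and is an assumption for the example only.

import numpy as np

dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
rng = np.random.default_rng(42)
spikes = np.zeros(1000, dtype=dtype)
spikes["sample_index"] = rng.integers(0, 30000, size=1000)
spikes["unit_index"] = rng.integers(0, 5, size=1000)
spikes["segment_index"] = rng.integers(0, 2, size=1000)

# np.lexsort uses the last key as the primary key: unit, then segment, then sample
order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
new_spikes = spikes[order]

unit_index = 3
u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
unit_spikes = new_spikes[u0:u1]

# within the unit block, segment_index is sorted too, so segments slice the same way
s0, s1 = np.searchsorted(unit_spikes["segment_index"], [1, 2], side="left")
print(len(unit_spikes), len(unit_spikes[s0:s1]))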
@@ -392,6 +404,7 @@ def compute_sliding_rp_violations(
     exclude_ref_period_below_ms=0.5,
     max_ref_period_ms=10,
     contamination_values=None,
+    correlograms_kwargs=dict(method="auto"),
 ):
     """
     Compute sliding refractory period violations, a metric developed by IBL which computes
@@ -417,6 +430,8 @@ def compute_sliding_rp_violations(
         Maximum refractory period to test in ms.
     contamination_values : 1d array or None, default: None
         The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5).
+    correlograms_kwargs : dict, default: dict(method="auto")
+        Additional keyword arguments used to compute the correlograms (only the "method" key is currently used).
 
     Returns
     -------
@@ -433,37 +448,40 @@ def compute_sliding_rp_violations(
     sorting = sorting_analyzer.sorting
     if unit_ids is None:
         unit_ids = sorting_analyzer.unit_ids
-    num_segs = sorting_analyzer.get_num_segments()
+
     fs = sorting_analyzer.sampling_frequency
 
     contamination = {}
 
-    # all units converted to seconds
-    for unit_id in unit_ids:
-        spike_train_list = []
+    spikes = sorting.to_spike_vector()
+    order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
+    new_spikes = spikes[order]
 
-        for segment_index in range(num_segs):
-            spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
-            if np.any(spike_train):
-                spike_train_list.append(spike_train)
+    for unit_id in unit_ids:
+        unit_index = sorting.id_to_index(unit_id)
+        u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
 
-        if not any([np.any(train) for train in spike_train_list]):
-            continue
+        sub_spikes = new_spikes[u0:u1].copy()
+        sub_spikes["unit_index"] = 0  # single unit sorting
 
-        unit_n_spikes = np.sum([len(train) for train in spike_train_list])
+        unit_n_spikes = len(sub_spikes)
         if unit_n_spikes <= min_spikes:
             contamination[unit_id] = np.nan
             continue
 
+        from spikeinterface.core.numpyextractors import NumpySorting
+
+        sub_sorting = NumpySorting(sub_spikes, fs, unit_ids=[unit_id])
+
         contamination[unit_id] = slidingRP_violations(
-            spike_train_list,
-            fs,
+            sub_sorting,
             duration,
             bin_size_ms,
             window_size_s,
             exclude_ref_period_below_ms,
             max_ref_period_ms,
             contamination_values,
+            correlograms_kwargs=correlograms_kwargs,
         )
 
     return contamination
@@ -479,6 +497,7 @@ class SlidingRPViolation(BaseMetric):
         "exclude_ref_period_below_ms": 0.5,
         "max_ref_period_ms": 10,
         "contamination_values": None,
+        "correlograms_kwargs": dict(method="auto"),
     }
     metric_columns = {"sliding_rp_violation": float}
     metric_descriptions = {
@@ -528,6 +547,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
 
     spikes = sorting.to_spike_vector()
     all_unit_ids = sorting.unit_ids
+
     synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids)
 
     synchrony_metrics_dict = {}
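A compact sketch of the single-unit extraction used by compute_sliding_rp_violations above: take one unit's block from the lexsorted spike vector, relabel its unit_index to 0, and the result can then be wrapped in a single-unit NumpySorting as the hunk does. The structured dtype is a toy-data assumption; only numpy is exercised here.

import numpy as np

dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
rng = np.random.default_rng(0)
spikes = np.zeros(500, dtype=dtype)
spikes["sample_index"] = np.sort(rng.integers(0, 30000 * 10, size=500))
spikes["unit_index"] = rng.integers(0, 4, size=500)

order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
new_spikes = spikes[order]

unit_index = 2
u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")

# copy before relabeling so the shared spike vector is not modified in place
sub_spikes = new_spikes[u0:u1].copy()
sub_spikes["unit_index"] = 0
print(len(sub_spikes), np.unique(sub_spikes["unit_index"]))
# sub_spikes would then be wrapped as in the hunk above:
# sub_sorting = NumpySorting(sub_spikes, fs, unit_ids=[unit_id])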
Firing ranges are set to NaN.") return {unit_id: np.nan for unit_id in unit_ids} + spikes = sorting.to_spike_vector() + order = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"])) + new_spikes = spikes[order] + # for each segment, we compute the firing rate histogram and we concatenate them - firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} + firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in unit_ids} for segment_index in range(sorting_analyzer.get_num_segments()): + s0, s1 = np.searchsorted(new_spikes["segment_index"], [segment_index, segment_index + 1], side="left") num_samples = sorting_analyzer.get_num_samples(segment_index) edges = np.arange(0, num_samples + 1, bin_size_samples) - + sub_data = new_spikes[s0:s1] for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + unit_index = sorting.id_to_index(unit_id) + u0, u1 = np.searchsorted(sub_data["unit_index"], [unit_index, unit_index + 1], side="left") + spike_times = new_spikes[s0 + u0 : s0 + u1]["sample_index"] spike_counts, _ = np.histogram(spike_times, bins=edges) firing_rates = spike_counts / bin_size_s firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) @@ -685,14 +712,10 @@ def compute_amplitude_cv_metrics( amps = sorting_analyzer.get_extension(amplitude_extension).get_data() - # precompute segment slice - segment_slices = [] - for segment_index in range(sorting_analyzer.get_num_segments()): - i0 = np.searchsorted(spikes["segment_index"], segment_index) - i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) - segment_slices.append(slice(i0, i1)) + order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"])) + new_spikes = spikes[order] + new_amps = amps[order] - all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: firing_rate = num_spikes[unit_id] / total_duration @@ -700,21 +723,26 @@ def compute_amplitude_cv_metrics( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency ) + unit_index = sorting.id_to_index(unit_id) + u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left") + sub_data = new_spikes[u0:u1] + amp_spreads = [] # bins and amplitude means are computed for each segment for segment_index in range(sorting_analyzer.get_num_segments()): sample_bin_edges = np.arange( 0, sorting_analyzer.get_num_samples(segment_index) + 1, temporal_bin_size_samples ) - spikes_in_segment = spikes[segment_slices[segment_index]] - amps_in_segment = amps[segment_slices[segment_index]] - unit_mask = spikes_in_segment["unit_index"] == all_unit_ids.index(unit_id) - spike_indices_unit = spikes_in_segment["sample_index"][unit_mask] - amps_unit = amps_in_segment[unit_mask] + + s0, s1 = np.searchsorted(sub_data["segment_index"], [segment_index, segment_index + 1], side="left") + spikes_in_segment = new_spikes[u0 + s0 : u0 + s1] + amps_unit = new_amps[u0 + s0 : u0 + s1] + spike_indices_unit = spikes_in_segment["sample_index"] amp_mean = np.abs(np.mean(amps_unit)) - for t0, t1 in zip(sample_bin_edges[:-1], sample_bin_edges[1:]): - i0 = np.searchsorted(spike_indices_unit, t0) - i1 = np.searchsorted(spike_indices_unit, t1) + + bounds = np.searchsorted(spike_indices_unit, sample_bin_edges, side="left") + + for i0, i1 in zip(bounds[:-1], bounds[1:]): 
@@ -700,21 +723,26 @@
             (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency
         )
 
+        unit_index = sorting.id_to_index(unit_id)
+        u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
+        sub_data = new_spikes[u0:u1]
+
         amp_spreads = []
         # bins and amplitude means are computed for each segment
         for segment_index in range(sorting_analyzer.get_num_segments()):
             sample_bin_edges = np.arange(
                 0, sorting_analyzer.get_num_samples(segment_index) + 1, temporal_bin_size_samples
             )
-            spikes_in_segment = spikes[segment_slices[segment_index]]
-            amps_in_segment = amps[segment_slices[segment_index]]
-            unit_mask = spikes_in_segment["unit_index"] == all_unit_ids.index(unit_id)
-            spike_indices_unit = spikes_in_segment["sample_index"][unit_mask]
-            amps_unit = amps_in_segment[unit_mask]
+
+            s0, s1 = np.searchsorted(sub_data["segment_index"], [segment_index, segment_index + 1], side="left")
+            spikes_in_segment = new_spikes[u0 + s0 : u0 + s1]
+            amps_unit = new_amps[u0 + s0 : u0 + s1]
+            spike_indices_unit = spikes_in_segment["sample_index"]
             amp_mean = np.abs(np.mean(amps_unit))
-            for t0, t1 in zip(sample_bin_edges[:-1], sample_bin_edges[1:]):
-                i0 = np.searchsorted(spike_indices_unit, t0)
-                i1 = np.searchsorted(spike_indices_unit, t1)
+
+            bounds = np.searchsorted(spike_indices_unit, sample_bin_edges, side="left")
+
+            for i0, i1 in zip(bounds[:-1], bounds[1:]):
                 amp_spreads.append(np.std(amps_unit[i0:i1]) / amp_mean)
 
         if len(amp_spreads) < min_num_bins:
@@ -1033,14 +1061,15 @@ def compute_drift_metrics(
 
     spike_locations_ext = sorting_analyzer.get_extension("spike_locations")
     spike_locations = spike_locations_ext.get_data()
-    # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit")
+
     spikes = sorting.to_spike_vector()
-    spike_locations_by_unit = {}
-    for unit_id in unit_ids:
-        unit_index = sorting.id_to_index(unit_id)
-        # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code
-        spike_mask = spikes["unit_index"] == unit_index
-        spike_locations_by_unit[unit_id] = spike_locations[spike_mask]
+    order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
+    new_spikes = spikes[order]
+    new_spike_locations = spike_locations[order]
+
+    order_bis = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]))
+    new_spikes_bis = spikes[order_bis]
+    new_spike_locations_bis = spike_locations[order_bis]
 
     interval_samples = int(interval_s * sorting_analyzer.sampling_frequency)
     assert direction in spike_locations.dtype.names, (
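For clarity, a toy illustration of why the drift hunk keeps two sorted copies of the spike vector: the (unit, segment, sample) ordering lets searchsorted slice per unit, while the (segment, unit, sample) ordering lets it slice per segment; the same index array must be applied to the side data (locations, amplitudes) so rows stay aligned. The dtype and data are toy assumptions.

import numpy as np

dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
rng = np.random.default_rng(2)
spikes = np.zeros(300, dtype=dtype)
spikes["sample_index"] = rng.integers(0, 30000, size=300)
spikes["unit_index"] = rng.integers(0, 3, size=300)
spikes["segment_index"] = rng.integers(0, 2, size=300)
locations = rng.normal(size=300)  # stand-in for spike_locations[direction]

# primary key = unit -> per-unit blocks
order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
# primary key = segment -> per-segment blocks
order_bis = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]))

new_spikes, new_locations = spikes[order], locations[order]
new_spikes_bis, new_locations_bis = spikes[order_bis], locations[order_bis]

u0, u1 = np.searchsorted(new_spikes["unit_index"], [1, 2], side="left")
s0, s1 = np.searchsorted(new_spikes_bis["segment_index"], [1, 2], side="left")
print(np.median(new_locations[u0:u1]), len(new_spikes_bis[s0:s1]))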
@@ -1065,35 +1094,36 @@ def compute_drift_metrics(
 
     # reference positions are the medians across segments
     reference_positions = np.zeros(len(unit_ids))
+    median_position_segments = None
+
     for i, unit_id in enumerate(unit_ids):
-        unit_ind = sorting.id_to_index(unit_id)
-        reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction])
+        unit_index = sorting.id_to_index(unit_id)
+        u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
+        reference_positions[i] = np.median(new_spike_locations[u0:u1][direction])
 
-    # now compute median positions and concatenate them over segments
-    median_position_segments = None
     for segment_index in range(sorting_analyzer.get_num_segments()):
+        s0, s1 = np.searchsorted(new_spikes_bis["segment_index"], [segment_index, segment_index + 1], side="left")
         seg_length = sorting_analyzer.get_num_samples(segment_index)
         num_bin_edges = seg_length // interval_samples + 1
         bins = np.arange(num_bin_edges) * interval_samples
-        spike_vector = sorting.to_spike_vector()
+        spikes_in_segment = new_spikes_bis[s0:s1]
+        spike_locations_in_segment = new_spike_locations_bis[s0:s1]
 
-        # retrieve spikes in segment
-        i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1])
-        spikes_in_segment = spike_vector[i0:i1]
-        spike_locations_in_segment = spike_locations[i0:i1]
-
-        # compute median positions (if less than min_spikes_per_interval, median position is 0)
         median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1))
-        for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])):
-            i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame])
-            spikes_in_bin = spikes_in_segment[i0:i1]
-            spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]
-
-            for i, unit_id in enumerate(unit_ids):
-                unit_ind = sorting.id_to_index(unit_id)
-                mask = spikes_in_bin["unit_index"] == unit_ind
-                if np.sum(mask) >= min_spikes_per_interval:
-                    median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask])
+
+        for i, unit_id in enumerate(unit_ids):
+            unit_index = sorting.id_to_index(unit_id)
+            u0, u1 = np.searchsorted(spikes_in_segment["unit_index"], [unit_index, unit_index + 1], side="left")
+            spikes_in_segment_of_unit = spikes_in_segment[u0:u1]
+            spike_locations_in_segment_of_unit = spike_locations_in_segment[u0:u1]
+
+            for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])):
+                i0, i1 = np.searchsorted(spikes_in_segment_of_unit["sample_index"], [start_frame, end_frame])
+                spikes_in_bin = spikes_in_segment_of_unit[i0:i1]
+                spike_locations_in_bin = spike_locations_in_segment_of_unit[i0:i1][direction]
+                if len(spikes_in_bin) >= min_spikes_per_interval:
+                    median_positions[i, bin_index] = np.median(spike_locations_in_bin)
+
         if median_position_segments is None:
             median_position_segments = median_positions
         else:
@@ -1220,27 +1250,33 @@ def compute_sd_ratio(
     tamplates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV)
 
     spikes = sorting.to_spike_vector()
+    order = np.lexsort((spikes["sample_index"], spikes["segment_index"], spikes["unit_index"]))
+    new_spikes = spikes[order]
+    new_spike_amplitudes = spike_amplitudes[order]
+
     sd_ratio = {}
+
     for unit_id in unit_ids:
-        unit_index = sorting_analyzer.sorting.id_to_index(unit_id)
+        unit_index = sorting.id_to_index(unit_id)
+        u0, u1 = np.searchsorted(new_spikes["unit_index"], [unit_index, unit_index + 1], side="left")
+        sub_data = new_spikes[u0:u1]
         spk_amp = []
 
-        for segment_index in range(sorting_analyzer.get_num_segments()):
+        for segment_index in range(sorting.get_num_segments()):
-            spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index)
-            spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False)
-            amplitudes = spike_amplitudes[spike_mask]
+            s0, s1 = np.searchsorted(sub_data["segment_index"], [segment_index, segment_index + 1], side="left")
+            spike_train = new_spikes[u0 + s0 : u0 + s1]["sample_index"]
+            amplitudes = new_spike_amplitudes[u0 + s0 : u0 + s1]
 
             censored_indices = find_duplicated_spikes(
                 spike_train,
                 censored_period,
                 method="keep_first_iterative",
             )
-
             spk_amp.append(np.delete(amplitudes, censored_indices))
 
-        spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))])
+        spk_amp = np.concatenate(spk_amp)
 
         if len(spk_amp) == 0:
             sd_ratio[unit_id] = np.nan
@@ -1255,6 +1291,8 @@ def compute_sd_ratio(
         best_channel = best_channels[unit_id]
         std_noise = noise_levels[best_channel]
 
+        n_samples = sorting_analyzer.get_total_samples()
+
         if correct_for_template_itself:
             # template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel]
@@ -1263,7 +1301,7 @@ def compute_sd_ratio(
 
             # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
             # TODO: Take into account that templates for different segments might differ.
-            p = nsamples * n_spikes[unit_id] / sorting_analyzer.get_total_samples()
+            p = nsamples * n_spikes[unit_id] / n_samples
             total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2
 
             std_noise = np.sqrt(std_noise**2 - total_variance)
@@ -1475,8 +1513,7 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val
 
 
 def slidingRP_violations(
-    spike_samples,
-    sample_rate,
+    sorting,
     duration,
     bin_size_ms=0.25,
     window_size_s=1,
@@ -1484,6 +1521,7 @@
     max_ref_period_ms=10,
     contamination_values=None,
     return_conf_matrix=False,
+    correlograms_kwargs=dict(method="auto"),
 ):
     """
     A metric developed by IBL which determines whether the refractory period violations
@@ -1495,8 +1533,6 @@
     ----------
-    spike_samples : ndarray_like or list (for multi-segment)
-        The spike times in samples.
-    sample_rate : float
-        The acquisition sampling rate.
+    sorting : BaseSorting
+        The sorting object containing the spike train(s) to test.
     bin_size_ms : float
         The size (in ms) of binning for the autocorrelogram.
     window_size_s : float, default: 1
@@ -1509,6 +1545,8 @@
         The contamination values to test, if None it is set to np.arange(0.5, 35, 0.5) / 100.
     return_conf_matrix : bool, default: False
         If True, the confidence matrix (n_contaminations, n_ref_periods) is returned.
+    correlograms_kwargs : dict, default: dict(method="auto")
+        Additional keyword arguments for correlogram computation (only the "method" key is currently used).
 
     Code adapted from:
     https://github.com/SteinmetzLab/slidingRefractory/blob/master/python/slidingRP/metrics.py#L166
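A standalone sketch of the backend selection that the next hunk implements: read "method" from a kwargs dict, defaulting to "auto", and fall back to numpy when numba is not importable. The availability check below is a generic stand-in, not the module's HAVE_NUMBA flag.

import importlib.util

def pick_correlogram_method(correlograms_kwargs=None):
    # "auto" prefers the numba backend when numba can be imported
    correlograms_kwargs = correlograms_kwargs or {}
    method = correlograms_kwargs.get("method", "auto")
    if method == "auto":
        have_numba = importlib.util.find_spec("numba") is not None
        method = "numba" if have_numba else "numpy"
    assert method in ("numpy", "numba"), f"Unknown correlogram method: {method}"
    return method

print(pick_correlogram_method())
print(pick_correlogram_method(dict(method="numpy")))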
@@ -1525,25 +1563,28 @@
     rp_centers = rp_edges + ((rp_edges[1] - rp_edges[0]) / 2)  # vector of refractory period durations to test
 
     # compute firing rate and spike count (concatenate for multi-segments)
-    n_spikes = len(np.concatenate(spike_samples))
+    n_spikes = len(sorting.to_spike_vector())
     firing_rate = n_spikes / duration
-    if np.isscalar(spike_samples[0]):
-        spike_samples_list = [spike_samples]
-    else:
-        spike_samples_list = spike_samples
-    # compute correlograms
-    correlogram = None
-    for spike_samples in spike_samples_list:
-        c0 = correlogram_for_one_segment(
-            spike_samples,
-            np.zeros(len(spike_samples), dtype="int8"),
-            bin_size=max(int(bin_size_ms / 1000 * sample_rate), 1),  # convert to sample counts
-            window_size=int(window_size_s * sample_rate),
-        )[0, 0]
-        if correlogram is None:
-            correlogram = c0
-        else:
-            correlogram += c0
+
+    method = correlograms_kwargs.get("method", "auto")
+    if method == "auto":
+        method = "numba" if HAVE_NUMBA else "numpy"
+
+    bin_size = max(int(bin_size_ms / 1000 * sorting.sampling_frequency), 1)
+    window_size = int(window_size_s * sorting.sampling_frequency)
+
+    if method == "numpy":
+        from spikeinterface.postprocessing.correlograms import _compute_correlograms_numpy
+
+        correlogram = _compute_correlograms_numpy(sorting, window_size, bin_size)[0, 0]
+    if method == "numba":
+        from spikeinterface.postprocessing.correlograms import _compute_correlograms_numba
+
+        correlogram = _compute_correlograms_numba(sorting, window_size, bin_size)[0, 0]
+
+    ## I don't get why this line is not giving exactly the same result as the correlogram function. I would question
+    # the choice of the bin_size above, but I am not the author of the code...
+    # correlogram = compute_correlograms(sorting, 2*window_size_s*1000, bin_size_ms, **correlograms_kwargs)[0][0, 0]
 
     correlogram_positive = correlogram[len(correlogram) // 2 :]
 
     conf_matrix = _compute_violations(
@@ -1693,18 +1734,25 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
 
     # compute the occurrence of each sample_index. Count >2 means there's synchrony
     _, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True)
 
-    sync_indices = unique_spike_index[counts >= 2]
-    sync_counts = counts[counts >= 2]
+    min_synchrony = 2
+    mask = counts >= min_synchrony
+    sync_indices = unique_spike_index[mask]
+    sync_counts = counts[mask]
+
+    all_syncs = np.unique(sync_counts)
+    num_bins = [np.size(synchrony_sizes[synchrony_sizes <= i]) for i in all_syncs]
+
+    indices = {}
+    for num_of_syncs in all_syncs:
+        indices[num_of_syncs] = np.flatnonzero(all_syncs == num_of_syncs)[0]
 
     for i, sync_index in enumerate(sync_indices):
         num_of_syncs = sync_counts[i]
-        units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)]
-
         # Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added
         # to the 2 simultaneous spike bins.
-        how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs])
-        synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1
+        units_with_sync = spikes[sync_index : sync_index + num_of_syncs]["unit_index"]
+        synchrony_counts[: num_bins[indices[num_of_syncs]], units_with_sync] += 1
 
     return synchrony_counts
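A toy numpy sketch of the synchrony counting idea above, assuming the spike vector is sorted by sample_index so that simultaneous spikes are contiguous: np.unique gives the first index and multiplicity of every sample, and each coincident event increments all synchrony-size bins it reaches.

import numpy as np

synchrony_sizes = np.array([2, 4, 8])
num_units = 3
# toy spike vector sorted by sample_index; unit_index per spike
sample_index = np.array([10, 10, 10, 25, 40, 40, 55])
unit_index = np.array([0, 1, 2, 0, 1, 2, 0])

synchrony_counts = np.zeros((synchrony_sizes.size, num_units), dtype=np.int64)

_, unique_spike_index, counts = np.unique(sample_index, return_index=True, return_counts=True)
sync_indices = unique_spike_index[counts >= 2]
sync_counts = counts[counts >= 2]

for start, num_of_syncs in zip(sync_indices, sync_counts):
    units_with_sync = unit_index[start : start + num_of_syncs]
    # an event with n coincident spikes counts toward every synchrony size <= n
    how_many_bins = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs])
    synchrony_counts[:how_many_bins, units_with_sync] += 1

print(synchrony_counts)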
@@ -1760,10 +1808,10 @@ def _compute_nb_violations_numba(spike_train, t_r):
         cache=False,
         parallel=True,
     )
-    def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, t_c, t_r):
-        n_units = len(nb_rp_violations)
+    def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, t_c, t_r, all_units_indices):
 
-        for i in numba.prange(n_units):
-            spike_train = spike_trains[spike_clusters == i]
+        for i in numba.prange(len(all_units_indices)):
+            unit_index = all_units_indices[i]
+            spike_train = spike_trains[spike_clusters == unit_index]
             n_v = _compute_nb_violations_numba(spike_train, t_r)
-            nb_rp_violations[i] += n_v
+            nb_rp_violations[unit_index] += n_v
diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py
index c0dd6c6033..1acc833875 100644
--- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py
+++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py
@@ -239,20 +239,22 @@ def test_unit_structure_in_output(small_sorting_analyzer):
             result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param)
             result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param)
 
+            error = "Problem with metric: " + metric_name
+
             if isinstance(result_all, dict):
-                assert list(result_all.keys()) == ["#3", "#9", "#4"]
-                assert list(result_sub.keys()) == ["#4", "#9"]
-                assert result_sub["#9"] == result_all["#9"]
-                assert result_sub["#4"] == result_all["#4"]
+                assert list(result_all.keys()) == ["#3", "#9", "#4"], error
+                assert list(result_sub.keys()) == ["#4", "#9"], error
+                assert result_sub["#9"] == result_all["#9"], error
+                assert result_sub["#4"] == result_all["#4"], error
             else:
                 for result_ind, result in enumerate(result_sub):
-                    assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"]
-                    assert result_sub[result_ind].keys() == set(["#4", "#9"])
+                    assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error
+                    assert result_sub[result_ind].keys() == set(["#4", "#9"]), error
 
-                    assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"]
-                    assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"]
+                    assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error
+                    assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error
 
 
 def test_unit_id_order_independence(small_sorting_analyzer):
@@ -297,9 +299,10 @@ def test_unit_id_order_independence(small_sorting_analyzer):
     )
 
     for metric, metric_2_data in quality_metrics_2.items():
-        assert quality_metrics_1[metric]["#3"] == metric_2_data[2]
-        assert quality_metrics_1[metric]["#9"] == metric_2_data[7]
-        assert quality_metrics_1[metric]["#4"] == metric_2_data[1]
+        error = "Problem with the metric " + metric
+        assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error
+        assert quality_metrics_1[metric]["#9"] == metric_2_data[7], error
+        assert quality_metrics_1[metric]["#4"] == metric_2_data[1], error
 
 
 def _sorting_violation():
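A self-contained sketch of the numba pattern used in the kernel above: a parallel loop over a selected subset of unit indices with a boolean mask per unit. The inner violation count is a naive stand-in for the library's helper, and the example assumes numba is installed.

import numpy as np
import numba

@numba.njit(parallel=True)
def count_rp_violations(spike_samples, spike_labels, unit_indices, t_r):
    # one counter per selected unit, aligned with unit_indices
    out = np.zeros(unit_indices.size, dtype=np.int64)
    for i in numba.prange(unit_indices.size):
        unit_index = unit_indices[i]
        st = spike_samples[spike_labels == unit_index]
        n_v = 0
        for j in range(st.size - 1):
            k = j + 1
            while k < st.size and st[k] - st[j] < t_r:
                n_v += 1
                k += 1
        out[i] = n_v
    return out

rng = np.random.default_rng(3)
samples = np.sort(rng.integers(0, 30000 * 10, size=5000)).astype(np.int64)
labels = rng.integers(0, 4, size=5000).astype(np.int32)
print(count_rp_violations(samples, labels, np.array([0, 2], dtype=np.int64), t_r=60))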
diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py
index ba66d0671c..7d6a0e67d2 100644
--- a/src/spikeinterface/metrics/spiketrain/metrics.py
+++ b/src/spikeinterface/metrics/spiketrain/metrics.py
@@ -25,12 +25,16 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs):
 
     num_segs = sorting.get_num_segments()
 
     num_spikes = {}
+
+    spikes = sorting.to_spike_vector()
+    unit_indices, total_counts = np.unique(spikes["unit_index"], return_counts=True)
     for unit_id in unit_ids:
-        n = 0
-        for segment_index in range(num_segs):
-            st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
-            n += st.size
-        num_spikes[unit_id] = n
+        unit_index = sorting.id_to_index(unit_id)
+        if unit_index in unit_indices:
+            idx = np.flatnonzero(unit_indices == unit_index)[0]
+            num_spikes[unit_id] = total_counts[idx]
+        else:
+            num_spikes[unit_id] = 0
 
     return num_spikes
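A minimal numpy sketch of the counting strategy above, including units with zero spikes; np.bincount with minlength is an equivalent alternative when unit indices are contiguous integers starting at 0. Data and unit count are toy assumptions.

import numpy as np

num_units = 5
unit_index_per_spike = np.array([0, 0, 2, 2, 2, 4])  # toy spike-vector "unit_index" field

# variant used in the patch: np.unique + membership test
unit_indices, total_counts = np.unique(unit_index_per_spike, return_counts=True)
num_spikes = {}
for unit_index in range(num_units):
    if unit_index in unit_indices:
        idx = np.flatnonzero(unit_indices == unit_index)[0]
        num_spikes[unit_index] = int(total_counts[idx])
    else:
        num_spikes[unit_index] = 0

# equivalent one-liner when unit indices run from 0 to num_units - 1
counts = np.bincount(unit_index_per_spike, minlength=num_units)
assert num_spikes == {i: int(c) for i, c in enumerate(counts)}
print(num_spikes)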