diff --git a/src/spikeinterface/preprocessing/detect_saturation.py b/src/spikeinterface/preprocessing/detect_saturation.py new file mode 100644 index 0000000000..5479b92284 --- /dev/null +++ b/src/spikeinterface/preprocessing/detect_saturation.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.node_pipeline import run_node_pipeline +from spikeinterface.core.node_pipeline import PeakDetector + + +EVENT_VECTOR_TYPE = [ + ("start_sample_index", "int64"), + ("stop_sample_index", "int64"), + ("segment_index", "int64"), + ("method_id", "U128"), +] + + +def _collapse_events(events): + """ + If events are detected at a chunk edge, they will be split in two. + This detects such cases and collapses them in a single record instead + :param events: + :return: + """ + order = np.lexsort((events["start_sample_index"], events["segment_index"])) + events = events[order] + to_drop = np.zeros(events.size, dtype=bool) + + # compute if duplicate + for i in np.arange(events.size - 1): + same = events["stop_sample_index"][i] == events["start_sample_index"][i + 1] + if same: + to_drop[i] = True + events["start_sample_index"][i + 1] = events["start_sample_index"][i] + + return events[~to_drop].copy() + + +class _DetectSaturation(PeakDetector): + + name = "detect_saturation" + preferred_mp_context = None + + def __init__( + self, + recording, + saturation_threshold, # 1200 uV + voltage_per_sec_threshold, # 1e-8 V.s-1 + proportion, + mute_window_samples, + ): + PeakDetector.__init__(self, recording, return_output=True) + + self.voltage_per_sec_threshold = voltage_per_sec_threshold + self.saturation_threshold = saturation_threshold + self.sampling_frequency = recording.get_sampling_frequency() + self.proportion = proportion + self.mute_window_samples = mute_window_samples + self._dtype = EVENT_VECTOR_TYPE + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + """ + Computes + :param data: [nc, ns]: voltage traces array + :param max_voltage: maximum value of the voltage: scalar or array of size nc (same units as data) + :param v_per_sec: maximum derivative of the voltage in V/s (or units/s) + :param fs: sampling frequency Hz (defaults to 30kHz) + :param proportion: 0 < proportion <1 of channels above threshold to consider the sample as saturated (0.2) + :param mute_window_samples=7: number of samples for the cosine taper applied to the saturation + :return: + saturation [ns]: boolean array indicating the saturated samples + mute [ns]: float array indicating the mute function to apply to the data [0-1] + """ + fs = self.sampling_frequency + + # first computes the saturated samples + max_voltage = np.atleast_1d(self.saturation_threshold)[:, np.newaxis] + + # 0.98 is empirically determined as the true saturating point is + # slightly lower than the documented saturation point of the probe + saturation = np.mean(np.abs(traces) > max_voltage * 0.98, axis=1) + + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.voltage_per_sec_threshold, axis=1) + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + saturation = np.logical_or(saturation > self.proportion, n_diff_saturated > self.proportion) + + intervals = np.where(np.diff(saturation, prepend=False, append=False))[0] + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=EVENT_VECTOR_TYPE) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["stop_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + events[i]["method_id"] = "saturation_detection" + + # Because we inherit PeakDetector, we must expose this "sample_index" + # array. However, it is not used and changing the value has no effect. + toto = np.array([0], dtype=[("sample_index", "int64")]) + + return (toto, events) + + +def detect_saturation( + recording, + saturation_threshold, # 1200 uV + voltage_per_sec_threshold, # 1e-8 V.s-1 + proportion=0.5, + mute_window_samples=7, + job_kwargs=None, +): + """ """ + if job_kwargs: + job_kwargs = {} + + job_kwargs = fix_job_kwargs(job_kwargs) + + node0 = _DetectSaturation( + recording, + saturation_threshold=saturation_threshold, + voltage_per_sec_threshold=voltage_per_sec_threshold, + proportion=proportion, + mute_window_samples=mute_window_samples, + ) + + _, events = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation events") + + return _collapse_events(events) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_saturation.py b/src/spikeinterface/preprocessing/tests/test_detect_saturation.py new file mode 100644 index 0000000000..5e3fee2525 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_detect_saturation.py @@ -0,0 +1,82 @@ +import numpy as np +import scipy.signal +import pandas as pd +from spikeinterface.core.numpyextractors import NumpyRecording +from spikeinterface.preprocessing.detect_saturation import detect_saturation + +# TODO: add pre-sets and document? or at least reccomend values in documentation probably easier + + +def test_saturation_detection(): + """ + TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity + we have an extra sample after due to taking the diff on the final saturation mask + this means we always take one sample before and one sample after the diff period, which is fine. + """ + sample_frequency = 30000 + chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below + job_kwargs = {"chunk_size": chunk_size} + + # cross a chunk boundary. Do not change without changing the below. + sat_value = 1200 * 1e-6 + data = np.random.uniform(low=-0.5, high=0.5, size=(150000, 384)) * 10 * 1e-6 + + # Design the Butterworth filter + sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") + + # Apply the filter to the data + data_seg_1 = scipy.signal.sosfiltfilt(sos, data, axis=0) + data_seg_2 = data_seg_1.copy() + + # Add test saturation at the start, end of recording + # as well as across and within chunks (30k samples). + # Two cases which are not tested are a single event + # exactly on the border, as it makes testing complex + # This was checked manually and any future breaking change + # on this function would be extremely unlikely only to break this case. + # fmt:off + all_starts = np.array([0, 29950, 45123, 90005, 149500]) + all_stops = np.array([1000, 30010, 45125, 90005, 149998]) + # fmt:on + + second_seg_offset = 1 + for start, stop in zip(all_starts, all_stops): + if start == stop: + data_seg_1[start] = sat_value + else: + data_seg_1[start : stop + 1, :] = sat_value + # differentiate the second segment for testing purposes + data_seg_2[start : stop + 1 + second_seg_offset, :] = sat_value + + recording = NumpyRecording([data_seg_1, data_seg_2], sample_frequency) + + events = detect_saturation( + recording, saturation_threshold=sat_value * 0.98, voltage_per_sec_threshold=1e-8, job_kwargs=job_kwargs + ) + + seg_1_events = events[np.where(events["segment_index"] == 0)] + seg_2_events = events[np.where(events["segment_index"] == 1)] + + # For the start times, all are one sample before the actual saturated + # period starts because the derivative threshold is exceeded at one + # sample before the saturation starts. Therefore this one-sample-offset + # on the start times is an implicit test that the derivative + # threshold is working properly. + for seg_events in [seg_1_events, seg_2_events]: + assert seg_events["start_sample_index"][0] == all_starts[0] + assert np.array_equal(seg_events["start_sample_index"][1:], np.array(all_starts)[1:] - 1) + + assert np.array_equal(seg_1_events["stop_sample_index"], np.array(all_stops) + 1) + assert np.array_equal(seg_2_events["stop_sample_index"], np.array(all_stops) + 1 + second_seg_offset) + + # Just do a quick test that a threshold slightly over the sat value is not detected. + # In this case we only see the derivative threshold detection. We do not play around with this + # threshold because the derivative threshold is not easy to predict (the baseline sample is random). + events = detect_saturation( + recording, + saturation_threshold=sat_value * (1.0 / 0.98) + 1e-6, + voltage_per_sec_threshold=1e-8, + job_kwargs=job_kwargs, + ) + assert events["start_sample_index"][0] == 1000 + assert events["stop_sample_index"][0] == 1001