140 changes: 140 additions & 0 deletions src/spikeinterface/preprocessing/detect_saturation.py
@@ -0,0 +1,140 @@
from __future__ import annotations

import numpy as np

from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core.node_pipeline import run_node_pipeline
from spikeinterface.core.node_pipeline import PeakDetector


EVENT_VECTOR_TYPE = [
("start_sample_index", "int64"),
("stop_sample_index", "int64"),
("segment_index", "int64"),
("method_id", "U128"),
]
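
# For illustration only (values are made up), a single collapsed event record reads like:
#   (start_sample_index=29950, stop_sample_index=30010, segment_index=0,
#    method_id="saturation_detection")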


def _collapse_events(events):
"""
If events are detected at a chunk edge, they will be split in two.
This detects such cases and collapses them in a single record instead
:param events:
:return:
"""
order = np.lexsort((events["start_sample_index"], events["segment_index"]))
events = events[order]
to_drop = np.zeros(events.size, dtype=bool)

    # merge an event into the next one when they abut at a chunk edge within the same segment
    for i in np.arange(events.size - 1):
        same_segment = events["segment_index"][i] == events["segment_index"][i + 1]
        contiguous = events["stop_sample_index"][i] == events["start_sample_index"][i + 1]
        if same_segment and contiguous:
            to_drop[i] = True
            events["start_sample_index"][i + 1] = events["start_sample_index"][i]

return events[~to_drop].copy()


class _DetectSaturation(PeakDetector):

name = "detect_saturation"
preferred_mp_context = None

def __init__(
self,
recording,
saturation_threshold, # 1200 uV
voltage_per_sec_threshold, # 1e-8 V.s-1
proportion,
mute_window_samples,
):
PeakDetector.__init__(self, recording, return_output=True)

self.voltage_per_sec_threshold = voltage_per_sec_threshold
self.saturation_threshold = saturation_threshold
self.sampling_frequency = recording.get_sampling_frequency()
self.proportion = proportion
self.mute_window_samples = mute_window_samples
self._dtype = EVENT_VECTOR_TYPE

def get_trace_margin(self):
return 0

def get_dtype(self):
return self._dtype

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
"""
Computes
:param data: [nc, ns]: voltage traces array
:param max_voltage: maximum value of the voltage: scalar or array of size nc (same units as data)
:param v_per_sec: maximum derivative of the voltage in V/s (or units/s)
:param fs: sampling frequency Hz (defaults to 30kHz)
:param proportion: 0 < proportion <1 of channels above threshold to consider the sample as saturated (0.2)
:param mute_window_samples=7: number of samples for the cosine taper applied to the saturation
:return:
saturation [ns]: boolean array indicating the saturated samples
mute [ns]: float array indicating the mute function to apply to the data [0-1]
"""
fs = self.sampling_frequency

        # first compute the saturated samples; traces are (num_samples, num_channels),
        # so a per-channel threshold must broadcast along the channel (last) axis
        max_voltage = np.atleast_1d(self.saturation_threshold)[np.newaxis, :]

        # 0.98 is empirically determined as the true saturating point is
        # slightly lower than the documented saturation point of the probe
        saturation = np.mean(np.abs(traces) > max_voltage * 0.98, axis=1)

        # then compute the proportion of channels whose voltage derivative exceeds the threshold
        n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.voltage_per_sec_threshold, axis=1)
        # the forward derivative is one sample shorter than the traces, so the last sample
        # is never checked by the derivative criterion; pad with 0 to match the length
        n_diff_saturated = np.r_[n_diff_saturated, 0]

        # if either criterion exceeds the given proportion of channels, label the sample as saturated
saturation = np.logical_or(saturation > self.proportion, n_diff_saturated > self.proportion)

intervals = np.where(np.diff(saturation, prepend=False, append=False))[0]
n_events = len(intervals) // 2 # Number of saturation periods
events = np.zeros(n_events, dtype=EVENT_VECTOR_TYPE)

for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])):
events[i]["start_sample_index"] = start + start_frame
events[i]["stop_sample_index"] = stop + start_frame
events[i]["segment_index"] = segment_index
events[i]["method_id"] = "saturation_detection"

        # Because we inherit from PeakDetector, we must return a "sample_index"
        # array. It is not used here and changing its value has no effect.
        dummy_sample_index = np.array([0], dtype=[("sample_index", "int64")])

        return (dummy_sample_index, events)


def detect_saturation(
    recording,
    saturation_threshold,  # 1200 uV
    voltage_per_sec_threshold,  # 1e-8 V.s-1
    proportion=0.5,
    mute_window_samples=7,
    job_kwargs=None,
):
    """
    Detect saturation events in a recording. A sample is labelled as saturated when more than
    `proportion` of the channels exceed `saturation_threshold` in absolute value (same units as
    the traces) or exceed `voltage_per_sec_threshold` on the forward derivative of the voltage.

    :return: structured array of events with fields "start_sample_index", "stop_sample_index",
        "segment_index" and "method_id", with events split at chunk edges collapsed
    """
    if job_kwargs is None:
        job_kwargs = {}

    job_kwargs = fix_job_kwargs(job_kwargs)

node0 = _DetectSaturation(
recording,
saturation_threshold=saturation_threshold,
voltage_per_sec_threshold=voltage_per_sec_threshold,
proportion=proportion,
mute_window_samples=mute_window_samples,
)

_, events = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation events")

return _collapse_events(events)
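

# A minimal, hedged usage sketch (not part of the original PR): it assumes traces returned in
# volts and uses `generate_recording` from `spikeinterface.core` purely for illustration; the
# threshold values below are illustrative, not recommended settings.
if __name__ == "__main__":
    from spikeinterface.core import generate_recording

    rec = generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[2.0])
    events = detect_saturation(
        rec,
        saturation_threshold=1200 * 1e-6,  # same units as the traces
        voltage_per_sec_threshold=1e-8,
        proportion=0.5,
        mute_window_samples=7,
        job_kwargs={"chunk_size": 30_000},
    )
    # one row per detected saturation period
    for ev in events:
        print(ev["segment_index"], ev["start_sample_index"], ev["stop_sample_index"])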
82 changes: 82 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_detect_saturation.py
@@ -0,0 +1,82 @@
import numpy as np
import scipy.signal
from spikeinterface.core.numpyextractors import NumpyRecording
from spikeinterface.preprocessing.detect_saturation import detect_saturation

# TODO: add presets and document them, or at least recommend values in the documentation (probably easier)


def test_saturation_detection():
"""
TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity
we have an extra sample after due to taking the diff on the final saturation mask
this means we always take one sample before and one sample after the diff period, which is fine.
"""
sample_frequency = 30000
chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below
job_kwargs = {"chunk_size": chunk_size}

    # The saturation periods below are designed to cross a chunk boundary.
    # Do not change these values without also updating the hard-coded starts / stops below.
sat_value = 1200 * 1e-6
data = np.random.uniform(low=-0.5, high=0.5, size=(150000, 384)) * 10 * 1e-6

# Design the Butterworth filter
sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos")

# Apply the filter to the data
data_seg_1 = scipy.signal.sosfiltfilt(sos, data, axis=0)
data_seg_2 = data_seg_1.copy()

    # Add test saturation periods at the start and end of the recording,
    # as well as across and within chunks (30k samples).
    # Events that start or stop exactly on a chunk border are not tested,
    # as that makes the test complex. This case was checked manually, and a future
    # breaking change to this function is extremely unlikely to break only that case.
# fmt:off
all_starts = np.array([0, 29950, 45123, 90005, 149500])
all_stops = np.array([1000, 30010, 45125, 90005, 149998])
# fmt:on

second_seg_offset = 1
for start, stop in zip(all_starts, all_stops):
if start == stop:
data_seg_1[start] = sat_value
else:
data_seg_1[start : stop + 1, :] = sat_value
        # make the second segment slightly different (stop offset by one sample) for testing purposes
data_seg_2[start : stop + 1 + second_seg_offset, :] = sat_value

recording = NumpyRecording([data_seg_1, data_seg_2], sample_frequency)

events = detect_saturation(
recording, saturation_threshold=sat_value * 0.98, voltage_per_sec_threshold=1e-8, job_kwargs=job_kwargs
)

seg_1_events = events[np.where(events["segment_index"] == 0)]
seg_2_events = events[np.where(events["segment_index"] == 1)]

    # Apart from the event starting at sample 0, all detected start times are one sample
    # before the actual saturated period because the derivative threshold is exceeded one
    # sample before the saturation starts. This one-sample offset on the start times is
    # therefore an implicit test that the derivative threshold is working properly.
for seg_events in [seg_1_events, seg_2_events]:
assert seg_events["start_sample_index"][0] == all_starts[0]
assert np.array_equal(seg_events["start_sample_index"][1:], np.array(all_starts)[1:] - 1)

assert np.array_equal(seg_1_events["stop_sample_index"], np.array(all_stops) + 1)
assert np.array_equal(seg_2_events["stop_sample_index"], np.array(all_stops) + 1 + second_seg_offset)

    # Quick check that saturation is not flagged when the threshold is slightly above the
    # saturating value. In this case only the derivative criterion fires. We do not vary that
    # threshold further because the derivative detection is hard to predict (the baseline sample is random).
events = detect_saturation(
recording,
saturation_threshold=sat_value * (1.0 / 0.98) + 1e-6,
voltage_per_sec_threshold=1e-8,
job_kwargs=job_kwargs,
)
assert events["start_sample_index"][0] == 1000
assert events["stop_sample_index"][0] == 1001