-
Notifications
You must be signed in to change notification settings - Fork 240
Compute saturation intervals ala IBL #4301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
oliche
wants to merge
22
commits into
SpikeInterface:main
Choose a base branch
from
int-brain-lab:sat_detection
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+222
−0
Draft
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
eb351bc
Add draft test.
JoeZiminski bffc5f2
Vis to test.
JoeZiminski 017b7ce
in-progress detect saturation rough.
JoeZiminski 231c21f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a38c0f7
Complete arguments.
JoeZiminski ddfd3b0
Merge branch 'sat_detection' of github.com:SpikeInterface/spikeinterf…
JoeZiminski 79d91df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] be9f5b1
remove unused arguments.
JoeZiminski f8e514c
Remove unused arguments.
JoeZiminski 2ec43d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8494ae0
Fix removed call to superclass init.
JoeZiminski 10f5474
Merge branch 'sat_detection' of github.com:SpikeInterface/spikeinterf…
JoeZiminski 0c356ca
WIP compute saturation intervals
oliche bfa5d24
minor refactor.
JoeZiminski 8402a89
Add asserts to tests.
JoeZiminski 6e95425
Add asserts to tests
JoeZiminski 22aa4bf
Finalise tests.
JoeZiminski 39ec5e4
detect saturation.
JoeZiminski 6e5786d
Fix transpose.
JoeZiminski d43a561
Add note.
JoeZiminski 69baa3e
Rename to test_detect-saturation.
JoeZiminski a9e762d
Remove TODO.
JoeZiminski File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
|
|
||
| from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording | ||
| from spikeinterface.core.job_tools import fix_job_kwargs | ||
| from spikeinterface.core.node_pipeline import run_node_pipeline | ||
| from spikeinterface.core.node_pipeline import PeakDetector | ||
|
|
||
|
|
||
| EVENT_VECTOR_TYPE = [ | ||
| ("start_sample_index", "int64"), | ||
| ("stop_sample_index", "int64"), | ||
| ("segment_index", "int64"), | ||
| ("method_id", "U128"), | ||
| ] | ||
|
|
||
|
|
||
| def _collapse_events(events): | ||
| """ | ||
| If events are detected at a chunk edge, they will be split in two. | ||
| This detects such cases and collapses them in a single record instead | ||
| :param events: | ||
| :return: | ||
| """ | ||
| order = np.lexsort((events["start_sample_index"], events["segment_index"])) | ||
| events = events[order] | ||
| to_drop = np.zeros(events.size, dtype=bool) | ||
|
|
||
| # compute if duplicate | ||
| for i in np.arange(events.size - 1): | ||
| same = events["stop_sample_index"][i] == events["start_sample_index"][i + 1] | ||
| if same: | ||
| to_drop[i] = True | ||
| events["start_sample_index"][i + 1] = events["start_sample_index"][i] | ||
|
|
||
| return events[~to_drop].copy() | ||
|
|
||
|
|
||
| class _DetectSaturation(PeakDetector): | ||
|
|
||
| name = "detect_saturation" | ||
| preferred_mp_context = None | ||
|
|
||
| def __init__( | ||
| self, | ||
| recording, | ||
| saturation_threshold, # 1200 uV | ||
| voltage_per_sec_threshold, # 1e-8 V.s-1 | ||
| proportion, | ||
| mute_window_samples, | ||
| ): | ||
| PeakDetector.__init__(self, recording, return_output=True) | ||
|
|
||
| self.voltage_per_sec_threshold = voltage_per_sec_threshold | ||
| self.saturation_threshold = saturation_threshold | ||
| self.sampling_frequency = recording.get_sampling_frequency() | ||
| self.proportion = proportion | ||
| self.mute_window_samples = mute_window_samples | ||
| self._dtype = EVENT_VECTOR_TYPE | ||
|
|
||
| def get_trace_margin(self): | ||
| return 0 | ||
|
|
||
| def get_dtype(self): | ||
| return self._dtype | ||
|
|
||
| def compute(self, traces, start_frame, end_frame, segment_index, max_margin): | ||
| """ | ||
| Computes | ||
| :param data: [nc, ns]: voltage traces array | ||
| :param max_voltage: maximum value of the voltage: scalar or array of size nc (same units as data) | ||
| :param v_per_sec: maximum derivative of the voltage in V/s (or units/s) | ||
| :param fs: sampling frequency Hz (defaults to 30kHz) | ||
| :param proportion: 0 < proportion <1 of channels above threshold to consider the sample as saturated (0.2) | ||
| :param mute_window_samples=7: number of samples for the cosine taper applied to the saturation | ||
| :return: | ||
| saturation [ns]: boolean array indicating the saturated samples | ||
| mute [ns]: float array indicating the mute function to apply to the data [0-1] | ||
| """ | ||
| fs = self.sampling_frequency | ||
|
|
||
| # first computes the saturated samples | ||
| max_voltage = np.atleast_1d(self.saturation_threshold)[:, np.newaxis] | ||
|
|
||
| # 0.98 is empirically determined as the true saturating point is | ||
| # slightly lower than the documented saturation point of the probe | ||
| saturation = np.mean(np.abs(traces) > max_voltage * 0.98, axis=1) | ||
|
|
||
| # then compute the derivative of the voltage saturation | ||
| n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.voltage_per_sec_threshold, axis=1) | ||
| # Note this means the velocity is not checked for the last sample in the | ||
| # check because we are taking the forward derivative | ||
| n_diff_saturated = np.r_[n_diff_saturated, 0] | ||
|
|
||
| # if either of those reaches more than the proportion of channels labels the sample as saturated | ||
| saturation = np.logical_or(saturation > self.proportion, n_diff_saturated > self.proportion) | ||
|
|
||
| intervals = np.where(np.diff(saturation, prepend=False, append=False))[0] | ||
| n_events = len(intervals) // 2 # Number of saturation periods | ||
| events = np.zeros(n_events, dtype=EVENT_VECTOR_TYPE) | ||
|
|
||
| for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): | ||
| events[i]["start_sample_index"] = start + start_frame | ||
| events[i]["stop_sample_index"] = stop + start_frame | ||
| events[i]["segment_index"] = segment_index | ||
| events[i]["method_id"] = "saturation_detection" | ||
|
|
||
| # Because we inherit PeakDetector, we must expose this "sample_index" | ||
| # array. However, it is not used and changing the value has no effect. | ||
| toto = np.array([0], dtype=[("sample_index", "int64")]) | ||
|
|
||
| return (toto, events) | ||
|
|
||
|
|
||
| def detect_saturation( | ||
| recording, | ||
| saturation_threshold, # 1200 uV | ||
| voltage_per_sec_threshold, # 1e-8 V.s-1 | ||
| proportion=0.5, | ||
| mute_window_samples=7, | ||
| job_kwargs=None, | ||
| ): | ||
| """ """ | ||
| if job_kwargs: | ||
| job_kwargs = {} | ||
|
|
||
| job_kwargs = fix_job_kwargs(job_kwargs) | ||
|
|
||
| node0 = _DetectSaturation( | ||
| recording, | ||
| saturation_threshold=saturation_threshold, | ||
| voltage_per_sec_threshold=voltage_per_sec_threshold, | ||
| proportion=proportion, | ||
| mute_window_samples=mute_window_samples, | ||
| ) | ||
|
|
||
| _, events = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation events") | ||
|
|
||
| return _collapse_events(events) | ||
82 changes: 82 additions & 0 deletions
82
src/spikeinterface/preprocessing/tests/test_detect_saturation.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| import numpy as np | ||
| import scipy.signal | ||
| import pandas as pd | ||
| from spikeinterface.core.numpyextractors import NumpyRecording | ||
| from spikeinterface.preprocessing.detect_saturation import detect_saturation | ||
|
|
||
| # TODO: add pre-sets and document? or at least reccomend values in documentation probably easier | ||
|
|
||
|
|
||
| def test_saturation_detection(): | ||
| """ | ||
| TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity | ||
| we have an extra sample after due to taking the diff on the final saturation mask | ||
| this means we always take one sample before and one sample after the diff period, which is fine. | ||
| """ | ||
| sample_frequency = 30000 | ||
| chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below | ||
| job_kwargs = {"chunk_size": chunk_size} | ||
|
|
||
| # cross a chunk boundary. Do not change without changing the below. | ||
| sat_value = 1200 * 1e-6 | ||
| data = np.random.uniform(low=-0.5, high=0.5, size=(150000, 384)) * 10 * 1e-6 | ||
|
|
||
| # Design the Butterworth filter | ||
| sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") | ||
|
|
||
| # Apply the filter to the data | ||
| data_seg_1 = scipy.signal.sosfiltfilt(sos, data, axis=0) | ||
| data_seg_2 = data_seg_1.copy() | ||
|
|
||
| # Add test saturation at the start, end of recording | ||
| # as well as across and within chunks (30k samples). | ||
| # Two cases which are not tested are a single event | ||
| # exactly on the border, as it makes testing complex | ||
| # This was checked manually and any future breaking change | ||
| # on this function would be extremely unlikely only to break this case. | ||
| # fmt:off | ||
| all_starts = np.array([0, 29950, 45123, 90005, 149500]) | ||
| all_stops = np.array([1000, 30010, 45125, 90005, 149998]) | ||
| # fmt:on | ||
|
|
||
| second_seg_offset = 1 | ||
| for start, stop in zip(all_starts, all_stops): | ||
| if start == stop: | ||
| data_seg_1[start] = sat_value | ||
| else: | ||
| data_seg_1[start : stop + 1, :] = sat_value | ||
| # differentiate the second segment for testing purposes | ||
| data_seg_2[start : stop + 1 + second_seg_offset, :] = sat_value | ||
|
|
||
| recording = NumpyRecording([data_seg_1, data_seg_2], sample_frequency) | ||
|
|
||
| events = detect_saturation( | ||
| recording, saturation_threshold=sat_value * 0.98, voltage_per_sec_threshold=1e-8, job_kwargs=job_kwargs | ||
| ) | ||
|
|
||
| seg_1_events = events[np.where(events["segment_index"] == 0)] | ||
| seg_2_events = events[np.where(events["segment_index"] == 1)] | ||
|
|
||
| # For the start times, all are one sample before the actual saturated | ||
| # period starts because the derivative threshold is exceeded at one | ||
| # sample before the saturation starts. Therefore this one-sample-offset | ||
| # on the start times is an implicit test that the derivative | ||
| # threshold is working properly. | ||
| for seg_events in [seg_1_events, seg_2_events]: | ||
| assert seg_events["start_sample_index"][0] == all_starts[0] | ||
| assert np.array_equal(seg_events["start_sample_index"][1:], np.array(all_starts)[1:] - 1) | ||
|
|
||
| assert np.array_equal(seg_1_events["stop_sample_index"], np.array(all_stops) + 1) | ||
| assert np.array_equal(seg_2_events["stop_sample_index"], np.array(all_stops) + 1 + second_seg_offset) | ||
|
|
||
| # Just do a quick test that a threshold slightly over the sat value is not detected. | ||
| # In this case we only see the derivative threshold detection. We do not play around with this | ||
| # threshold because the derivative threshold is not easy to predict (the baseline sample is random). | ||
| events = detect_saturation( | ||
| recording, | ||
| saturation_threshold=sat_value * (1.0 / 0.98) + 1e-6, | ||
| voltage_per_sec_threshold=1e-8, | ||
| job_kwargs=job_kwargs, | ||
| ) | ||
| assert events["start_sample_index"][0] == 1000 | ||
| assert events["stop_sample_index"][0] == 1001 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.