diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index de25944bd2..d2d8674168 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,6 +20,7 @@ PreprocessingPipeline, ) +from .detect_artifacts import detect_artifact_periods, detect_period_artifacts_by_envelope # for snippets from .align_snippets import AlignSnippets from warnings import warn diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py new file mode 100644 index 0000000000..6cb22ac49f --- /dev/null +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording +from spikeinterface.preprocessing.rectify import RectifyRecording +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording +from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording +from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype +import numpy as np + + +artifact_dtype = [ + ("start_index", "int64"), + ("stop_index", "int64"), + ("segment_index", "int64"), +] + +extended_artifact_dtype = artifact_dtype + [ + # TODO +] + + +_internal_dtype = [ + ("sample_index", "int64"), + ("segment_index", "int64"), + ("front", "bool") +] + + +def detect_artifact_periods( + recording, + method="envelope", + method_kwargs=None, + job_kwargs=None, +): + """ + + """ + + if method_kwargs is None: + method_kwargs = dict() + + if method == "envelope": + artifacts, envelope = detect_period_artifacts_by_envelope(recording, **method_kwargs, job_kwargs=job_kwargs) + elif method == "saturation": + raise NotImplementedError("Soon") + + else: + raise ValueError("") + + return artifacts + + + +## detect_period_artifacts_saturation Zone + + + + +## detect_period_artifacts_by_envelope Zone + +class DetectThresholdCrossing(PeakDetector): + + name = "threshold_crossings" + preferred_mp_context = None + + def __init__( + self, + recording, + detect_threshold=5, + noise_levels=None, + seed=None, + noise_levels_kwargs=dict(), + ): + PeakDetector.__init__(self, recording, return_output=True) + if noise_levels is None: + random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + self.abs_thresholds = noise_levels * detect_threshold + self._dtype = np.dtype(_internal_dtype) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + z = np.median(traces / self.abs_thresholds, 1) + threshold_mask = np.diff((z > 1) != 0, axis=0) + indices = np.flatnonzero(threshold_mask) + threshold_crossings = np.zeros(indices.size, dtype=self._dtype) + threshold_crossings["sample_index"] = indices + threshold_crossings["segment_index"] = segment_index + threshold_crossings["front"][::2] = True + threshold_crossings["front"][1::2] = False + return (threshold_crossings,) + + +def detect_period_artifacts_by_envelope( + recording, + detect_threshold=5, + # min_duration_ms=50, + freq_max=20.0, + seed=None, + job_kwargs=None, + random_slices_kwargs=None, +): + """ + Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of + a global envelope of the channels. + + Parameters + ---------- + recording : RecordingExtractor + The recording extractor to detect putative artifacts + detect_threshold : float, default: 5 + The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` + freq_max : float, default: 20 + The maximum frequency for the low pass filter used + seed : int | None, default: None + Random seed for `get_noise_levels`. + If none, `get_noise_levels` uses `seed=0`. + **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + + """ + + envelope = RectifyRecording(recording) + envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) + envelope = CommonReferenceRecording(envelope) + + from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ) + + # _, job_kwargs = split_job_kwargs(noise_levels_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + if random_slices_kwargs is None: + random_slices_kwargs = {} + else: + random_slices_kwargs = random_slices_kwargs.copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + + node0 = DetectThresholdCrossing( + recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + ) + + threshold_crossings = run_node_pipeline( + envelope, + [node0], + job_kwargs, + job_name="detect threshold crossings", + ) + + order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) + threshold_crossings = threshold_crossings[order] + + artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) + + + return artifacts, envelope + + +# tools + +def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): + + num_seg = recording.get_num_segments() + + final_artifacts = [] + for seg_index in range(num_seg): + mask = artifacts["segment_index"] == seg_index + sub_thr = artifacts[mask] + if len(sub_thr) > 0: + if not sub_thr["front"][0]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = 0 + local_thr["front"] = True + sub_thr = np.hstack((local_thr, sub_thr)) + if sub_thr["front"][-1]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = recording.get_num_samples(seg_index) + local_thr["front"] = False + sub_thr = np.hstack((sub_thr, local_thr)) + + local_artifact = np.zeros(sub_thr.size/2, dtype=artifact_dtype) + local_artifact["start_index"] = sub_thr["sample_index"][::2] + local_artifact["stop_index"] = sub_thr["sample_index"][1::2] + local_artifact["segment_index"] = seg_index + final_artifacts.append(local_artifact) + + if len(final_artifacts) > 0: + final_artifacts = np.concatenate(final_artifacts) + else: + final_artifacts = np.zeros(0, dtype=artifact_dtype) + return final_artifacts \ No newline at end of file diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index fe9d95c506..47839db7a0 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -50,7 +50,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts +# from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { # filter stuff @@ -90,7 +90,7 @@ DirectionalDerivativeRecording: directional_derivative, AstypeRecording: astype, UnsignedToSignedRecording: unsigned_to_signed, - SilencedArtifactsRecording: silence_artifacts, + # SilencedArtifactsRecording: silence_artifacts, } # we control import in the preprocessing init by setting an __all__ diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index b1ae00b64c..8006342847 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -4,221 +4,89 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -from spikeinterface.preprocessing.rectify import RectifyRecording -from spikeinterface.preprocessing.common_reference import CommonReferenceRecording -from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype import numpy as np -class DetectThresholdCrossing(PeakDetector): - - name = "threshold_crossings" - preferred_mp_context = None - - def __init__( - self, - recording, - detect_threshold=5, - noise_levels=None, - seed=None, - noise_levels_kwargs=dict(), - ): - PeakDetector.__init__(self, recording, return_output=True) - if noise_levels is None: - random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("front", "bool")]) - - def get_trace_margin(self): - return 0 - - def get_dtype(self): - return self._dtype - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces / self.abs_thresholds, 1) - threshold_mask = np.diff((z > 1) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) - threshold_crossings = np.zeros(indices.size, dtype=self._dtype) - threshold_crossings["sample_index"] = indices - threshold_crossings["front"][::2] = True - threshold_crossings["front"][1::2] = False - return (threshold_crossings,) - - -def detect_period_artifacts_by_envelope( - recording, - detect_threshold=5, - min_duration_ms=50, - freq_max=20.0, - seed=None, - noise_levels=None, - **noise_levels_kwargs, -): - """ - Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of - a global envelope of the channels. - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to detect putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels`. - If none, `get_noise_levels` uses `seed=0`. - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - """ - - envelope = RectifyRecording(recording) - envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) - envelope = CommonReferenceRecording(envelope) - - from spikeinterface.core.node_pipeline import ( - run_node_pipeline, - ) - - _, job_kwargs = split_job_kwargs(noise_levels_kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) - - node0 = DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - threshold_crossings = run_node_pipeline( - recording, - [node0], - job_kwargs, - job_name="detect threshold crossings", - ) - - order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) - threshold_crossings = threshold_crossings[order] - - periods = [] - fs = recording.sampling_frequency - max_duration_samples = int(min_duration_ms * fs / 1000) - num_seg = recording.get_num_segments() - - for seg_index in range(num_seg): - sub_periods = [] - mask = threshold_crossings["segment_index"] == seg_index - sub_thr = threshold_crossings[mask] - if len(sub_thr) > 0: - local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) - if not sub_thr["front"][0]: - local_thr["sample_index"] = 0 - local_thr["front"] = True - sub_thr = np.hstack((local_thr, sub_thr)) - if sub_thr["front"][-1]: - local_thr["sample_index"] = recording.get_num_samples(seg_index) - local_thr["front"] = False - sub_thr = np.hstack((sub_thr, local_thr)) - - indices = np.flatnonzero(np.diff(sub_thr["front"])) - for i, j in zip(indices[:-1], indices[1:]): - if sub_thr["front"][i]: - start = sub_thr["sample_index"][i] - end = sub_thr["sample_index"][j] - if end - start > max_duration_samples: - sub_periods.append((start, end)) - - periods.append(sub_periods) - - return periods, envelope - - -class SilencedArtifactsRecording(SilencedPeriodsRecording): - """ - Silence user-defined periods from recording extractor traces. The code will construct - an enveloppe of the recording (as a low pass filtered version of the traces) and detect - threshold crossings to identify the periods to silence. The periods are then silenced either - on a per channel basis or across all channels by replacing the values by zeros or by - adding gaussian noise with the same variance as the one in the recordings - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to silence putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise", default: "zeros" - Determines what periods are replaced by. Can be one of the following: - - - "zeros": Artifacts are replaced by zeros. - - - "noise": The periods are filled with a gaussion noise that has the - same variance that the one in the recordings, on a per channel - basis - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - Returns - ------- - silenced_recording : SilencedArtifactsRecording - The recording extractor after silencing detected artifacts - """ - - _precomputable_kwarg_names = ["list_periods"] - - def __init__( - self, - recording, - detect_threshold=5, - verbose=False, - freq_max=20.0, - min_duration_ms=50, - mode="zeros", - noise_levels=None, - seed=None, - list_periods=None, - **noise_levels_kwargs, - ): - - if list_periods is None: - list_periods, _ = detect_period_artifacts_by_envelope( - recording, - detect_threshold=detect_threshold, - min_duration_ms=min_duration_ms, - freq_max=freq_max, - seed=seed, - noise_levels=noise_levels, - **noise_levels_kwargs, - ) - - if verbose: - for i, periods in enumerate(list_periods): - total_time = np.sum([end - start for start, end in periods]) - percentage = 100 * total_time / recording.get_num_samples(i) - print(f"{percentage}% of segment {i} has been flagged as artifactual") - - SilencedPeriodsRecording.__init__( - self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - -# function for API -silence_artifacts = define_function_handling_dict_from_class( - source_class=SilencedArtifactsRecording, name="silence_artifacts" -) +# class SilencedArtifactsRecording(SilencedPeriodsRecording): +# """ +# Silence user-defined periods from recording extractor traces. The code will construct +# an enveloppe of the recording (as a low pass filtered version of the traces) and detect +# threshold crossings to identify the periods to silence. The periods are then silenced either +# on a per channel basis or across all channels by replacing the values by zeros or by +# adding gaussian noise with the same variance as the one in the recordings + +# Parameters +# ---------- +# recording : RecordingExtractor +# The recording extractor to silence putative artifacts +# artifacts : np.array, None +# The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` +# freq_max : float, default: 20 +# The maximum frequency for the low pass filter used +# min_duration_ms : float, default: 50 +# The minimum duration for a threshold crossing to be considered as an artefact. +# noise_levels : array +# Noise levels if already computed +# seed : int | None, default: None +# Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. +# If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. +# mode : "zeros" | "noise", default: "zeros" +# Determines what periods are replaced by. Can be one of the following: + +# - "zeros": Artifacts are replaced by zeros. + +# - "noise": The periods are filled with a gaussion noise that has the +# same variance that the one in the recordings, on a per channel +# basis +# **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + +# Returns +# ------- +# silenced_recording : SilencedArtifactsRecording +# The recording extractor after silencing detected artifacts +# """ + +# _precomputable_kwarg_names = ["artifacts"] + +# def __init__( +# self, +# recording, +# artifacts=None, +# detect_threshold=5, +# verbose=False, +# freq_max=20.0, +# min_duration_ms=50, +# mode="zeros", +# noise_levels=None, +# seed=None, +# list_periods=None, +# **noise_levels_kwargs, +# ): + +# if artifacts is None: +# from spikeinterface.preprocessing import detect_artifacts +# artifacts = detect_artifact_periods( +# recording, +# detect_threshold=detect_threshold, +# min_duration_ms=min_duration_ms, +# freq_max=freq_max, +# seed=seed, +# noise_levels=noise_levels, +# **noise_levels_kwargs, +# ) + +# if verbose: +# for i, periods in enumerate(artifacts): +# total_time = np.sum([end - start for start, end in periods]) +# percentage = 100 * total_time / recording.get_num_samples(i) +# print(f"{percentage}% of segment {i} has been flagged as artifactual") + +# SilencedPeriodsRecording.__init__( +# self, recording, artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs +# ) + + +# # function for API +# silence_artifacts = define_function_handling_dict_from_class( +# source_class=SilencedArtifactsRecording, name="silence_artifacts" +# ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py new file mode 100644 index 0000000000..52e8d927f9 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -0,0 +1,13 @@ +from spikeinterface.core import generate_recording +from spikeinterface.preprocessing import detect_artifact_periods + + +def test_detect_artifact_periods(): + # one segment only + rec = generate_recording(durations=[10.0, 10]) + artifacts = detect_artifact_periods(rec, method="envelope", + method_kwargs=dict(detect_threshold=5, freq_max=5.0), + ) + +if __name__ == "__main__": + test_detect_artifact_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py index 2baa4bf1b3..ad70540f40 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py @@ -2,15 +2,15 @@ import numpy as np -from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import silence_artifacts +# from spikeinterface.core import generate_recording +# from spikeinterface.preprocessing import silence_artifacts -def test_silence_artifacts(): - # one segment only - rec = generate_recording(durations=[10.0, 10]) - new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) +# def test_silence_artifacts(): +# # one segment only +# rec = generate_recording(durations=[10.0, 10]) +# new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) -if __name__ == "__main__": - test_silence_artifacts() +# if __name__ == "__main__": +# test_silence_artifacts()