Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class FilterRecording(BasePreprocessor):
btype : "bandpass" | "highpass", default: "bandpass"
Type of the filter
margin_ms : float, default: 5.0
Margin in ms on border to avoid border effect
Margin in ms on border to avoid border effect.
coeff : array | None, default: None
Filter coefficients in the filter_mode form.
dtype : dtype or None, default: None
Expand Down Expand Up @@ -78,7 +78,8 @@ def __init__(
filter_order=5,
ftype="butter",
filter_mode="sos",
margin_ms=5.0,
margin_ms=5,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
margin_ms=5,
margin_ms=5.0,

max_margin_s=5,
add_reflect_padding=False,
coeff=None,
dtype=None,
Expand Down Expand Up @@ -138,6 +139,22 @@ def __init__(
direction=direction,
)

def adjust_margin_ms_for_highpass(self, freq_min, max_margin_s):
# compute margin as 3 times the period of the highpass cutoff
margin_ms = 3 * (1000.0 / freq_min)
# limit max margin
max_margin_ms = max_margin_s * 1000.0
if margin_ms > max_margin_ms:
margin_ms = max_margin_ms
return margin_ms

def adjust_margin_ms_for_notch(self, max_margin_s, q, f0):
margin_ms = (3 / np.pi) * (q / f0) * 1000.0
max_margin_ms = max_margin_s * 1000.0
if margin_ms < max_margin_ms:
margin_ms = max_margin_ms
return margin_ms


class FilterRecordingSegment(BasePreprocessorSegment):
def __init__(
Expand Down Expand Up @@ -217,8 +234,11 @@ class BandpassFilterRecording(FilterRecording):
The highpass cutoff frequency in Hz
freq_max : float
The lowpass cutoff frequency in Hz
margin_ms : float
Margin in ms on border to avoid border effect
margin_ms : float | str, default: "auto"
Margin in ms on border to avoid border effect.
If "auto", margin is computed as 3 times the filter highpass cutoff period.
max_margin_s : float, default: 5
Maximum margin in seconds when margin_ms is set to "auto".
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
{}
Expand All @@ -229,7 +249,11 @@ class BandpassFilterRecording(FilterRecording):
The bandpass-filtered recording extractor object
"""

def __init__(self, recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0, dtype=None, **filter_kwargs):
def __init__(
self, recording, freq_min=300.0, freq_max=6000.0, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs
):
if margin_ms == "auto":
margin_ms = self.adjust_margin_ms_for_highpass(freq_min, max_margin_s)
FilterRecording.__init__(
self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs
)
Expand All @@ -250,8 +274,11 @@ class HighpassFilterRecording(FilterRecording):
The recording extractor to be re-referenced
freq_min : float
The highpass cutoff frequency in Hz
margin_ms : float
Margin in ms on border to avoid border effect
margin_ms : float | str, default: "auto"
Margin in ms on border to avoid border effect.
If "auto", margin is computed as 3 times the filter highpass cutoff period.
max_margin_s : float, default: 5
Maximum margin in seconds when margin_ms is set to "auto".
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
{}
Expand All @@ -262,7 +289,9 @@ class HighpassFilterRecording(FilterRecording):
The highpass-filtered recording extractor object
"""

def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filter_kwargs):
def __init__(self, recording, freq_min=300.0, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs):
if margin_ms == "auto":
margin_ms = self.adjust_margin_ms_for_highpass(freq_min, max_margin_s)
FilterRecording.__init__(
self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs
)
Expand All @@ -271,7 +300,7 @@ def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filte
self._kwargs.update(filter_kwargs)


class NotchFilterRecording(BasePreprocessor):
class NotchFilterRecording(FilterRecording):
"""
Parameters
----------
Expand All @@ -283,25 +312,27 @@ class NotchFilterRecording(BasePreprocessor):
The quality factor of the notch filter
dtype : None | dtype, default: None
dtype of recording. If None, will take from `recording`
margin_ms : float, default: 5.0
margin_ms : float | str, default: "auto"
Margin in ms on border to avoid border effect
max_margin_s : float, default: 5
Maximum margin in seconds when margin_ms is set to "auto".

Returns
-------
filter_recording : NotchFilterRecording
The notch-filtered recording extractor object
"""

def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
# coeef is 'ba' type
fn = 0.5 * float(recording.get_sampling_frequency())
def __init__(self, recording, freq=3000, q=30, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs):
import scipy.signal

if margin_ms == "auto":
margin_ms = self.adjust_margin_ms_for_notch(max_margin_s, q, freq)

fn = 0.5 * float(recording.get_sampling_frequency())
coeff = scipy.signal.iirnotch(freq / fn, q)

if dtype is None:
dtype = recording.get_dtype()
dtype = np.dtype(dtype)
dtype = fix_dtype(recording, dtype)

# if uint --> unsupported
if dtype.kind == "u":
Expand All @@ -310,15 +341,12 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
"to specify a signed type (e.g. 'int16', 'float32')"
)

BasePreprocessor.__init__(self, recording, dtype=dtype)
FilterRecording.__init__(
self, recording, coeff=coeff, filter_mode="ba", margin_ms=margin_ms, dtype=dtype, **filter_kwargs
)
self.annotate(is_filtered=True)

sf = recording.get_sampling_frequency()
margin = int(margin_ms * sf / 1000.0)
for parent_segment in recording._recording_segments:
self.add_recording_segment(FilterRecordingSegment(parent_segment, coeff, "ba", margin, dtype))

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)
self._kwargs.update(filter_kwargs)


# functions for API
Expand Down