Skip to content

Commit 443810f

Browse files
authored
Merge pull request #327 from EIT-ALIVE/326-add-multiple-digital-notch-filter
Add Multiple Digital Notch filter
2 parents ed46b39 + 5295c00 commit 443810f

File tree

12 files changed

+1020
-10
lines changed

12 files changed

+1020
-10
lines changed
File renamed without changes.

docs/api/filters/mdn.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
::: eitprocessing.filters.mdn.MDNFilter

docs/api/plotting/filter.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
::: eitprocessing.plotting.filter.FilterPlotting

eitprocessing/filters/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
from abc import ABC, abstractmethod
2-
from typing import NoReturn
2+
from typing import TypeVar
33

4-
import numpy.typing as npt
4+
import numpy as np
5+
6+
from eitprocessing.datahandling.eitdata import EITData
7+
from tests.test_breath_detection import ContinuousData
8+
9+
T = TypeVar("T", bound=np.ndarray | ContinuousData | EITData)
510

611

712
class TimeDomainFilter(ABC):
@@ -10,6 +15,6 @@ class TimeDomainFilter(ABC):
1015
available_in_gui = True
1116

1217
@abstractmethod
13-
def apply_filter(self, input_data: npt.ArrayLike) -> NoReturn:
18+
def apply(self, input_data: T, **kwargs) -> T:
1419
"""Apply the filter to the input data."""
1520
...

eitprocessing/filters/butterworth_filters.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import sys
2+
import warnings
23
from dataclasses import InitVar, dataclass
34
from typing import Literal
45

56
import numpy as np
6-
import numpy.typing as npt
77
from scipy import signal
88

99
from eitprocessing.filters import TimeDomainFilter
10+
from eitprocessing.utils import make_capture
1011

1112

1213
@dataclass(kw_only=True)
@@ -131,18 +132,40 @@ def _check_init(self, ignore_max_order: bool) -> None: # noqa:C901
131132
msg = f"Invalid `sample_frequency` ({self.sample_frequency}). Must be positive"
132133
raise ValueError(msg)
133134

134-
def apply_filter(self, input_data: npt.ArrayLike, axis: int = -1) -> np.ndarray:
135+
def apply_filter(self, *args, **kwargs) -> np.ndarray:
136+
"""Deprecated method. Use `apply()` instead."""
137+
warnings.warn("The `apply_filter` method is deprecated. Use `apply` instead.", DeprecationWarning, stacklevel=2)
138+
return self.apply(*args, **kwargs)
139+
140+
def apply(self, input_data: np.ndarray, axis: int = -1, captures: dict | None = None) -> np.ndarray:
135141
"""Apply the filter to the input data.
136142
137143
Args:
138144
input_data: Data to be filtered. If the input data has more than one axis,
139145
the filter is applied to the last axis.
140146
axis: Data axis the filter should be applied to. This defaults to the last axis,
141147
assuming this to be the time axis of the input data.
148+
captures:
149+
Optional. A dictionary to capture intermediate date, useful for plotting and debugging.
142150
143151
Returns:
144152
The filtered output with the same shape as the input data.
145153
"""
154+
capture = make_capture(captures)
155+
capture("unfiltered_data", input_data)
156+
capture("sample_frequency", self.sample_frequency)
157+
158+
match self.filter_type:
159+
case "lowpass":
160+
capture("low_pass_frequency", self.cutoff_frequency)
161+
case "highpass":
162+
capture("high_pass_frequency", self.cutoff_frequency)
163+
case "bandpass":
164+
capture("low_pass_frequency", self.cutoff_frequency[1])
165+
capture("high_pass_frequency", self.cutoff_frequency[0])
166+
case "bandstop":
167+
capture("frequency_bands", self.cutoff_frequency, append_to_list=True)
168+
146169
if np.any(np.isnan(input_data)):
147170
msg = "Input data contains NaN-values."
148171
exc = ValueError(msg)
@@ -161,7 +184,9 @@ def apply_filter(self, input_data: npt.ArrayLike, axis: int = -1) -> np.ndarray:
161184
output="sos",
162185
)
163186

164-
return signal.sosfiltfilt(sos, input_data, axis=axis)
187+
filtered_data = signal.sosfiltfilt(sos, input_data, axis=axis)
188+
capture("filtered_data", filtered_data)
189+
return filtered_data
165190

166191

167192
@dataclass(kw_only=True)

eitprocessing/filters/mdn.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import copy
2+
import math
3+
import warnings
4+
from dataclasses import dataclass
5+
from typing import TypeVar, cast, overload
6+
7+
import numpy as np
8+
from scipy import signal
9+
10+
from eitprocessing.datahandling.continuousdata import ContinuousData
11+
from eitprocessing.datahandling.eitdata import EITData
12+
from eitprocessing.filters import TimeDomainFilter
13+
from eitprocessing.plotting.filter import FilterPlotting
14+
from eitprocessing.utils import _CaptureFunc, make_capture
15+
16+
MINUTE = 60
17+
NOISE_FREQUENCY_LIMIT: float = 220 / MINUTE
18+
DEFAULT_AXIS: int = 0
19+
20+
# TODO: centralize settings (these should be shared with e.g. RateDetection)
21+
UPPER_RESPIRATORY_RATE_LIMIT: float = 85 / MINUTE
22+
UPPER_HEART_RATE_LIMIT: float = 210 / MINUTE
23+
24+
T = TypeVar("T", bound=np.ndarray | ContinuousData | EITData)
25+
26+
27+
MISSING = object()
28+
29+
30+
@dataclass(frozen=True, kw_only=True)
31+
class MDNFilter(TimeDomainFilter):
32+
"""Multiple Digital Notch filter.
33+
34+
This filter is used to remove heart rate noise from EIT data. A band stop filter removes heart rate ± the notch
35+
distance. This is repeated for every harmonic of the heart rate below the noise frequency limit. Lastly, a low pass
36+
filter removes noise above the noise frequency limit.
37+
38+
By default, the notch distance is set to 0.166... Hz (10 BPM), and the noise frequency limit is
39+
set to 3.66... Hz (220 BPM).
40+
41+
Warning:
42+
The respiratory and heart rate should be in provided Hz, not BPM. We recommend defining `MINUTE = 60` and using,
43+
e.g., `heart_rate=80 / MINUTE` to manually set the heart rate to 80 BPM.
44+
45+
Warning:
46+
This filter was designed to remove heart rate noise from EIT data, and testing in a limited number of cases. The
47+
filter may not work as expected for other data types, different cohorts or non-traditional ventilation modes.
48+
Use at your own discretion.
49+
50+
Args:
51+
respiratory_rate: the respiratory rate of the subject in Hz
52+
heart_rate: the heart rate of the subject in Hz
53+
noise_frequency_limit: the highest frequency to filter in Hz
54+
notch_distance: the half width of the band stop filter's frequency range
55+
"""
56+
57+
respiratory_rate: float
58+
heart_rate: float
59+
noise_frequency_limit: float = 220 / MINUTE
60+
notch_distance: float = 10 / MINUTE
61+
order: int = 10
62+
63+
def __post_init__(self):
64+
if self.respiratory_rate > UPPER_RESPIRATORY_RATE_LIMIT:
65+
msg = (
66+
f"The provided respiratory rate ({self.respiratory_rate:.1f}) "
67+
f"is higher than {UPPER_RESPIRATORY_RATE_LIMIT} Hz "
68+
f"({UPPER_RESPIRATORY_RATE_LIMIT * MINUTE} BPM). "
69+
"Make sure to use the correct unit (Hz, not BPM)."
70+
)
71+
warnings.warn(msg, UserWarning, stacklevel=2)
72+
if self.respiratory_rate <= 0:
73+
msg = f"The provided respiratory rate ({self.respiratory_rate:.2f}) must be positive."
74+
raise ValueError(msg)
75+
76+
if self.heart_rate > UPPER_HEART_RATE_LIMIT:
77+
msg = (
78+
f"The provided heart rate ({self.heart_rate:.1f}) is higher "
79+
f"than {UPPER_HEART_RATE_LIMIT} Hz ({UPPER_HEART_RATE_LIMIT * MINUTE} BPM). "
80+
"Make sure this is correct, and to use the correct unit."
81+
)
82+
warnings.warn(msg, UserWarning, stacklevel=2)
83+
if self.heart_rate <= 0:
84+
msg = f"The provided heart rate ({self.heart_rate:.2f}) must be positive."
85+
raise ValueError(msg)
86+
87+
if self.respiratory_rate >= self.heart_rate:
88+
msg = (
89+
f"The respiratory rate ({self.respiratory_rate:.1f} Hz) is equal to or higher than the heart "
90+
f"rate ({self.heart_rate:.1f} Hz)."
91+
)
92+
raise ValueError(msg)
93+
94+
@overload
95+
def apply(
96+
self, input_data: np.ndarray, sample_frequency: float, axis: int = 0, captures: dict | None = None
97+
) -> np.ndarray: ...
98+
99+
@overload
100+
def apply(self, input_data: ContinuousData, captures: dict | None = None, **kwargs) -> ContinuousData: ...
101+
102+
@overload
103+
def apply(self, input_data: EITData, captures: dict | None = None, **kwargs) -> EITData: ...
104+
105+
def apply( # pyright: ignore[reportInconsistentOverload]
106+
self,
107+
input_data: T,
108+
sample_frequency: float | object = MISSING,
109+
axis: int | object = MISSING,
110+
captures: dict | None = None,
111+
**kwargs,
112+
) -> T:
113+
"""Filter data using multiple digital notch filters.
114+
115+
Args:
116+
input_data: The data to filter. Can be a numpy array, ContinuousData, or EITData.
117+
sample_frequency:
118+
The sample frequency of the data. Should be provided when using a numpy array. If using
119+
ContinuousData or EITData, this will be taken from the data object.
120+
axis:
121+
The axis along which to apply the filter. Should only be provided when using a numpy array. Defaults to
122+
the first axis (0).
123+
captures:
124+
A dictionary to capture intermediate results for debugging or analysis. If provided, it will store the
125+
number of harmonics and the frequency bands used for filtering.
126+
**kwargs: Additional keyword arguments to pass to the ContinuousData or EITData object (e.g., `label`).
127+
"""
128+
capture = make_capture(captures)
129+
capture("low_pass_frequency", self.noise_frequency_limit)
130+
capture("unfiltered_data", input_data)
131+
132+
sample_frequency_, axis_, data = self._validate_arguments(
133+
input_data=input_data, sample_frequency=sample_frequency, axis=axis
134+
)
135+
136+
# Ensure the data is filtered up to the point where lower_limit would be larger than the noise frequency limit
137+
n_harmonics = math.floor((self.noise_frequency_limit + self.notch_distance) / self.heart_rate)
138+
capture("n_harmonics", n_harmonics)
139+
140+
for harmonic in range(1, n_harmonics + 1):
141+
data = self._filter_harmonic_with_bandstop(data, harmonic, axis_, sample_frequency_, capture)
142+
143+
# Filter everything above noise limit
144+
sos = signal.butter(
145+
N=self.order,
146+
Wn=self.noise_frequency_limit,
147+
fs=sample_frequency_,
148+
btype="low",
149+
output="sos",
150+
)
151+
new_data = signal.sosfiltfilt(sos, data, axis_)
152+
153+
if isinstance(input_data, np.ndarray):
154+
capture("filtered_data", new_data)
155+
return new_data
156+
157+
# TODO: Replace with input_data.update(...) when implemented
158+
return_object = copy.deepcopy(input_data)
159+
for attr, value in kwargs.items():
160+
setattr(return_object, attr, value)
161+
162+
if isinstance(return_object, ContinuousData):
163+
return_object.values = new_data
164+
elif isinstance(return_object, EITData):
165+
return_object.pixel_impedance = new_data
166+
167+
capture("filtered_data", return_object)
168+
return return_object
169+
170+
def _validate_arguments(
171+
self,
172+
input_data: np.ndarray | ContinuousData | EITData,
173+
sample_frequency: float | object,
174+
axis: int | object,
175+
) -> tuple[float, int, np.ndarray]:
176+
if isinstance(input_data, ContinuousData | EITData):
177+
if sample_frequency is not MISSING:
178+
msg = "Sample frequency should not be provided when using ContinuousData or EITData."
179+
raise ValueError(msg)
180+
181+
if axis is not MISSING:
182+
msg = "Axis should not be provided when using ContinuousData or EITData."
183+
raise ValueError(msg)
184+
185+
if isinstance(input_data, ContinuousData):
186+
data = input_data.values
187+
sample_frequency_ = cast("float", input_data.sample_frequency)
188+
axis_ = 0
189+
elif isinstance(input_data, EITData):
190+
data = input_data.pixel_impedance
191+
sample_frequency_ = cast("float", input_data.sample_frequency)
192+
axis_ = 0
193+
elif isinstance(input_data, np.ndarray):
194+
data = input_data
195+
axis_ = DEFAULT_AXIS if axis is MISSING else axis
196+
axis_ = cast("int", axis_)
197+
if sample_frequency is MISSING:
198+
msg = "Sample frequency must be provided when using a numpy array."
199+
raise ValueError(msg)
200+
sample_frequency_: float = cast("float", sample_frequency)
201+
else:
202+
msg = f"Invalid input data type ({type(input_data)}). Must be a numpy array, ContinuousData, or EITData."
203+
raise TypeError(msg)
204+
205+
if not sample_frequency_:
206+
msg = "Sample frequency must be provided."
207+
raise ValueError(msg)
208+
return sample_frequency_, axis_, data
209+
210+
def _filter_harmonic_with_bandstop(
211+
self,
212+
data_: np.ndarray,
213+
harmonic: int,
214+
axis: int,
215+
sample_frequency: float,
216+
capture: _CaptureFunc,
217+
) -> np.ndarray:
218+
lower_limit = self.heart_rate * harmonic - self.notch_distance
219+
upper_limit = self.heart_rate * harmonic + self.notch_distance
220+
221+
if harmonic == 1:
222+
new_lower_limit = (self.heart_rate + self.respiratory_rate) / 2
223+
lower_limit = max(lower_limit, new_lower_limit)
224+
225+
sos = signal.butter(
226+
N=self.order,
227+
Wn=[lower_limit, upper_limit],
228+
fs=sample_frequency,
229+
btype="bandstop",
230+
output="sos",
231+
)
232+
233+
capture("frequency_bands", (lower_limit, upper_limit), append_to_list=True)
234+
235+
return signal.sosfiltfilt(sos, data_, axis=axis)
236+
237+
@property
238+
def plotting(self) -> FilterPlotting:
239+
"""Return the plotting class for this filter."""
240+
from eitprocessing.plotting.filter import FilterPlotting
241+
242+
return FilterPlotting()

0 commit comments

Comments
 (0)