|
| 1 | +import copy |
| 2 | +import math |
| 3 | +import warnings |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import TypeVar, cast, overload |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from scipy import signal |
| 9 | + |
| 10 | +from eitprocessing.datahandling.continuousdata import ContinuousData |
| 11 | +from eitprocessing.datahandling.eitdata import EITData |
| 12 | +from eitprocessing.filters import TimeDomainFilter |
| 13 | +from eitprocessing.plotting.filter import FilterPlotting |
| 14 | +from eitprocessing.utils import _CaptureFunc, make_capture |
| 15 | + |
| 16 | +MINUTE = 60 |
| 17 | +NOISE_FREQUENCY_LIMIT: float = 220 / MINUTE |
| 18 | +DEFAULT_AXIS: int = 0 |
| 19 | + |
| 20 | +# TODO: centralize settings (these should be shared with e.g. RateDetection) |
| 21 | +UPPER_RESPIRATORY_RATE_LIMIT: float = 85 / MINUTE |
| 22 | +UPPER_HEART_RATE_LIMIT: float = 210 / MINUTE |
| 23 | + |
| 24 | +T = TypeVar("T", bound=np.ndarray | ContinuousData | EITData) |
| 25 | + |
| 26 | + |
| 27 | +MISSING = object() |
| 28 | + |
| 29 | + |
| 30 | +@dataclass(frozen=True, kw_only=True) |
| 31 | +class MDNFilter(TimeDomainFilter): |
| 32 | + """Multiple Digital Notch filter. |
| 33 | +
|
| 34 | + This filter is used to remove heart rate noise from EIT data. A band stop filter removes heart rate ± the notch |
| 35 | + distance. This is repeated for every harmonic of the heart rate below the noise frequency limit. Lastly, a low pass |
| 36 | + filter removes noise above the noise frequency limit. |
| 37 | +
|
| 38 | + By default, the notch distance is set to 0.166... Hz (10 BPM), and the noise frequency limit is |
| 39 | + set to 3.66... Hz (220 BPM). |
| 40 | +
|
| 41 | + Warning: |
| 42 | + The respiratory and heart rate should be in provided Hz, not BPM. We recommend defining `MINUTE = 60` and using, |
| 43 | + e.g., `heart_rate=80 / MINUTE` to manually set the heart rate to 80 BPM. |
| 44 | +
|
| 45 | + Warning: |
| 46 | + This filter was designed to remove heart rate noise from EIT data, and testing in a limited number of cases. The |
| 47 | + filter may not work as expected for other data types, different cohorts or non-traditional ventilation modes. |
| 48 | + Use at your own discretion. |
| 49 | +
|
| 50 | + Args: |
| 51 | + respiratory_rate: the respiratory rate of the subject in Hz |
| 52 | + heart_rate: the heart rate of the subject in Hz |
| 53 | + noise_frequency_limit: the highest frequency to filter in Hz |
| 54 | + notch_distance: the half width of the band stop filter's frequency range |
| 55 | + """ |
| 56 | + |
| 57 | + respiratory_rate: float |
| 58 | + heart_rate: float |
| 59 | + noise_frequency_limit: float = 220 / MINUTE |
| 60 | + notch_distance: float = 10 / MINUTE |
| 61 | + order: int = 10 |
| 62 | + |
| 63 | + def __post_init__(self): |
| 64 | + if self.respiratory_rate > UPPER_RESPIRATORY_RATE_LIMIT: |
| 65 | + msg = ( |
| 66 | + f"The provided respiratory rate ({self.respiratory_rate:.1f}) " |
| 67 | + f"is higher than {UPPER_RESPIRATORY_RATE_LIMIT} Hz " |
| 68 | + f"({UPPER_RESPIRATORY_RATE_LIMIT * MINUTE} BPM). " |
| 69 | + "Make sure to use the correct unit (Hz, not BPM)." |
| 70 | + ) |
| 71 | + warnings.warn(msg, UserWarning, stacklevel=2) |
| 72 | + if self.respiratory_rate <= 0: |
| 73 | + msg = f"The provided respiratory rate ({self.respiratory_rate:.2f}) must be positive." |
| 74 | + raise ValueError(msg) |
| 75 | + |
| 76 | + if self.heart_rate > UPPER_HEART_RATE_LIMIT: |
| 77 | + msg = ( |
| 78 | + f"The provided heart rate ({self.heart_rate:.1f}) is higher " |
| 79 | + f"than {UPPER_HEART_RATE_LIMIT} Hz ({UPPER_HEART_RATE_LIMIT * MINUTE} BPM). " |
| 80 | + "Make sure this is correct, and to use the correct unit." |
| 81 | + ) |
| 82 | + warnings.warn(msg, UserWarning, stacklevel=2) |
| 83 | + if self.heart_rate <= 0: |
| 84 | + msg = f"The provided heart rate ({self.heart_rate:.2f}) must be positive." |
| 85 | + raise ValueError(msg) |
| 86 | + |
| 87 | + if self.respiratory_rate >= self.heart_rate: |
| 88 | + msg = ( |
| 89 | + f"The respiratory rate ({self.respiratory_rate:.1f} Hz) is equal to or higher than the heart " |
| 90 | + f"rate ({self.heart_rate:.1f} Hz)." |
| 91 | + ) |
| 92 | + raise ValueError(msg) |
| 93 | + |
| 94 | + @overload |
| 95 | + def apply( |
| 96 | + self, input_data: np.ndarray, sample_frequency: float, axis: int = 0, captures: dict | None = None |
| 97 | + ) -> np.ndarray: ... |
| 98 | + |
| 99 | + @overload |
| 100 | + def apply(self, input_data: ContinuousData, captures: dict | None = None, **kwargs) -> ContinuousData: ... |
| 101 | + |
| 102 | + @overload |
| 103 | + def apply(self, input_data: EITData, captures: dict | None = None, **kwargs) -> EITData: ... |
| 104 | + |
| 105 | + def apply( # pyright: ignore[reportInconsistentOverload] |
| 106 | + self, |
| 107 | + input_data: T, |
| 108 | + sample_frequency: float | object = MISSING, |
| 109 | + axis: int | object = MISSING, |
| 110 | + captures: dict | None = None, |
| 111 | + **kwargs, |
| 112 | + ) -> T: |
| 113 | + """Filter data using multiple digital notch filters. |
| 114 | +
|
| 115 | + Args: |
| 116 | + input_data: The data to filter. Can be a numpy array, ContinuousData, or EITData. |
| 117 | + sample_frequency: |
| 118 | + The sample frequency of the data. Should be provided when using a numpy array. If using |
| 119 | + ContinuousData or EITData, this will be taken from the data object. |
| 120 | + axis: |
| 121 | + The axis along which to apply the filter. Should only be provided when using a numpy array. Defaults to |
| 122 | + the first axis (0). |
| 123 | + captures: |
| 124 | + A dictionary to capture intermediate results for debugging or analysis. If provided, it will store the |
| 125 | + number of harmonics and the frequency bands used for filtering. |
| 126 | + **kwargs: Additional keyword arguments to pass to the ContinuousData or EITData object (e.g., `label`). |
| 127 | + """ |
| 128 | + capture = make_capture(captures) |
| 129 | + capture("low_pass_frequency", self.noise_frequency_limit) |
| 130 | + capture("unfiltered_data", input_data) |
| 131 | + |
| 132 | + sample_frequency_, axis_, data = self._validate_arguments( |
| 133 | + input_data=input_data, sample_frequency=sample_frequency, axis=axis |
| 134 | + ) |
| 135 | + |
| 136 | + # Ensure the data is filtered up to the point where lower_limit would be larger than the noise frequency limit |
| 137 | + n_harmonics = math.floor((self.noise_frequency_limit + self.notch_distance) / self.heart_rate) |
| 138 | + capture("n_harmonics", n_harmonics) |
| 139 | + |
| 140 | + for harmonic in range(1, n_harmonics + 1): |
| 141 | + data = self._filter_harmonic_with_bandstop(data, harmonic, axis_, sample_frequency_, capture) |
| 142 | + |
| 143 | + # Filter everything above noise limit |
| 144 | + sos = signal.butter( |
| 145 | + N=self.order, |
| 146 | + Wn=self.noise_frequency_limit, |
| 147 | + fs=sample_frequency_, |
| 148 | + btype="low", |
| 149 | + output="sos", |
| 150 | + ) |
| 151 | + new_data = signal.sosfiltfilt(sos, data, axis_) |
| 152 | + |
| 153 | + if isinstance(input_data, np.ndarray): |
| 154 | + capture("filtered_data", new_data) |
| 155 | + return new_data |
| 156 | + |
| 157 | + # TODO: Replace with input_data.update(...) when implemented |
| 158 | + return_object = copy.deepcopy(input_data) |
| 159 | + for attr, value in kwargs.items(): |
| 160 | + setattr(return_object, attr, value) |
| 161 | + |
| 162 | + if isinstance(return_object, ContinuousData): |
| 163 | + return_object.values = new_data |
| 164 | + elif isinstance(return_object, EITData): |
| 165 | + return_object.pixel_impedance = new_data |
| 166 | + |
| 167 | + capture("filtered_data", return_object) |
| 168 | + return return_object |
| 169 | + |
| 170 | + def _validate_arguments( |
| 171 | + self, |
| 172 | + input_data: np.ndarray | ContinuousData | EITData, |
| 173 | + sample_frequency: float | object, |
| 174 | + axis: int | object, |
| 175 | + ) -> tuple[float, int, np.ndarray]: |
| 176 | + if isinstance(input_data, ContinuousData | EITData): |
| 177 | + if sample_frequency is not MISSING: |
| 178 | + msg = "Sample frequency should not be provided when using ContinuousData or EITData." |
| 179 | + raise ValueError(msg) |
| 180 | + |
| 181 | + if axis is not MISSING: |
| 182 | + msg = "Axis should not be provided when using ContinuousData or EITData." |
| 183 | + raise ValueError(msg) |
| 184 | + |
| 185 | + if isinstance(input_data, ContinuousData): |
| 186 | + data = input_data.values |
| 187 | + sample_frequency_ = cast("float", input_data.sample_frequency) |
| 188 | + axis_ = 0 |
| 189 | + elif isinstance(input_data, EITData): |
| 190 | + data = input_data.pixel_impedance |
| 191 | + sample_frequency_ = cast("float", input_data.sample_frequency) |
| 192 | + axis_ = 0 |
| 193 | + elif isinstance(input_data, np.ndarray): |
| 194 | + data = input_data |
| 195 | + axis_ = DEFAULT_AXIS if axis is MISSING else axis |
| 196 | + axis_ = cast("int", axis_) |
| 197 | + if sample_frequency is MISSING: |
| 198 | + msg = "Sample frequency must be provided when using a numpy array." |
| 199 | + raise ValueError(msg) |
| 200 | + sample_frequency_: float = cast("float", sample_frequency) |
| 201 | + else: |
| 202 | + msg = f"Invalid input data type ({type(input_data)}). Must be a numpy array, ContinuousData, or EITData." |
| 203 | + raise TypeError(msg) |
| 204 | + |
| 205 | + if not sample_frequency_: |
| 206 | + msg = "Sample frequency must be provided." |
| 207 | + raise ValueError(msg) |
| 208 | + return sample_frequency_, axis_, data |
| 209 | + |
| 210 | + def _filter_harmonic_with_bandstop( |
| 211 | + self, |
| 212 | + data_: np.ndarray, |
| 213 | + harmonic: int, |
| 214 | + axis: int, |
| 215 | + sample_frequency: float, |
| 216 | + capture: _CaptureFunc, |
| 217 | + ) -> np.ndarray: |
| 218 | + lower_limit = self.heart_rate * harmonic - self.notch_distance |
| 219 | + upper_limit = self.heart_rate * harmonic + self.notch_distance |
| 220 | + |
| 221 | + if harmonic == 1: |
| 222 | + new_lower_limit = (self.heart_rate + self.respiratory_rate) / 2 |
| 223 | + lower_limit = max(lower_limit, new_lower_limit) |
| 224 | + |
| 225 | + sos = signal.butter( |
| 226 | + N=self.order, |
| 227 | + Wn=[lower_limit, upper_limit], |
| 228 | + fs=sample_frequency, |
| 229 | + btype="bandstop", |
| 230 | + output="sos", |
| 231 | + ) |
| 232 | + |
| 233 | + capture("frequency_bands", (lower_limit, upper_limit), append_to_list=True) |
| 234 | + |
| 235 | + return signal.sosfiltfilt(sos, data_, axis=axis) |
| 236 | + |
| 237 | + @property |
| 238 | + def plotting(self) -> FilterPlotting: |
| 239 | + """Return the plotting class for this filter.""" |
| 240 | + from eitprocessing.plotting.filter import FilterPlotting |
| 241 | + |
| 242 | + return FilterPlotting() |
0 commit comments