Skip to content

Commit d52b032

Browse files
olichealejoe91
andauthored
Fix bug spatial filter #4175 (#4286)
Co-authored-by: Alessio Buccino <[email protected]>
1 parent 7130d4b commit d52b032

File tree

5 files changed

+103
-77
lines changed

5 files changed

+103
-77
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class BaseRecording(BaseRecordingSnippets):
3737
"noise_level_std_scaled",
3838
"noise_level_mad_raw",
3939
"noise_level_mad_scaled",
40+
"noise_level_rms_raw",
41+
"noise_level_rms_scaled",
4042
]
4143

4244
def __init__(self, sampling_frequency: float, channel_ids: list, dtype):

src/spikeinterface/extractors/cbin_ibl.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import probeinterface
88

99
from spikeinterface.core import BaseRecording, BaseRecordingSegment
10-
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
10+
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts_from_probe
1111
from spikeinterface.core.core_tools import define_function_from_class
1212

1313

@@ -44,22 +44,13 @@ class CompressedBinaryIblExtractor(BaseRecording):
4444

4545
installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n"
4646

47-
def __init__(
48-
self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None
49-
):
47+
def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None):
5048
from neo.rawio.spikeglxrawio import read_meta_file
5149

5250
try:
5351
import mtscomp
5452
except ImportError:
5553
raise ImportError(self.installation_mesg)
56-
if cbin_file is not None:
57-
warnings.warn(
58-
"The `cbin_file` argument is deprecated and will be removed in version 0.104.0, please use `cbin_file_path` instead",
59-
DeprecationWarning,
60-
stacklevel=2,
61-
)
62-
cbin_file_path = cbin_file
6354
if cbin_file_path is None:
6455
folder_path = Path(folder_path)
6556
# check bands
@@ -124,8 +115,7 @@ def __init__(
124115
num_channels_per_adc = 16
125116
else: # NP1.0
126117
num_channels_per_adc = 12
127-
128-
sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc)
118+
sample_shifts = get_neuropixels_sample_shifts_from_probe(probe, num_channels_per_adc)
129119
self.set_property("inter_sample_shift", sample_shifts)
130120

131121
self._kwargs = {

src/spikeinterface/preprocessing/filter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor):
7979
def __init__(
8080
self,
8181
recording,
82-
band=[300.0, 6000.0],
82+
band=(300.0, 6000.0),
8383
btype="bandpass",
8484
filter_order=5,
8585
ftype="butter",
@@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f
370370
def causal_filter(
371371
recording,
372372
direction="forward",
373-
band=[300.0, 6000.0],
373+
band=(300.0, 6000.0),
374374
btype="bandpass",
375375
filter_order=5,
376376
ftype="butter",

src/spikeinterface/preprocessing/highpass_spatial_filter.py

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44

5-
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
6-
from .filter import fix_dtype
7-
from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin
5+
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording
6+
from spikeinterface.preprocessing.filter import fix_dtype
7+
from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin, get_noise_levels
88
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
99

1010

@@ -48,8 +48,17 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
4848
Order of spatial butterworth filter
4949
highpass_butter_wn : float, default: 0.01
5050
Critical frequency (with respect to Nyquist) of spatial butterworth filter
51+
epsilon : float, default: 0.003
52+
Value multiplied to RMS values to avoid division by zero during AGC.
53+
random_slice_kwargs : dict | None, default: None
54+
If not None, dictionary of arguments to be passed to `get_noise_levels` when computing
55+
noise levels.
5156
dtype : dtype, default: None
5257
The dtype of the output traces. If None, the dtype is the same as the input traces
58+
rms_values : np.ndarray | None, default: None
59+
If not None, array of RMS values for each channel to be used during AGC. If None, RMS values are computed
60+
from the recording. This is used to cache pre-computed RMS values, which are only computed once at
61+
initialization.
5362
5463
Returns
5564
-------
@@ -66,15 +75,18 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
6675

6776
def __init__(
6877
self,
69-
recording,
78+
recording: BaseRecording,
7079
n_channel_pad=60,
7180
n_channel_taper=0,
7281
direction="y",
7382
apply_agc=True,
7483
agc_window_length_s=0.1,
7584
highpass_butter_order=3,
7685
highpass_butter_wn=0.01,
86+
epsilon=0.003,
87+
random_slice_kwargs=None,
7788
dtype=None,
89+
rms_values=None,
7890
):
7991
BasePreprocessor.__init__(self, recording)
8092

@@ -115,6 +127,14 @@ def __init__(
115127
if not apply_agc:
116128
agc_window_length_s = None
117129

130+
# Compute or retrieve RMS values
131+
if rms_values is None:
132+
if "noise_level_rms_raw" in recording.get_property_keys():
133+
rms_values = recording.get_property("noise_level_rms_raw")
134+
else:
135+
random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs
136+
rms_values = get_noise_levels(recording, method="rms", return_scaled=False, **random_slice_kwargs)
137+
118138
# Pre-compute spatial filtering parameters
119139
butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn)
120140
sos_filter = scipy.signal.butter(**butter_kwargs, output="sos")
@@ -133,6 +153,8 @@ def __init__(
133153
order_f,
134154
order_r,
135155
dtype=dtype,
156+
epsilon=epsilon,
157+
rms_values=rms_values,
136158
)
137159
self.add_recording_segment(rec_segment)
138160

@@ -145,6 +167,7 @@ def __init__(
145167
agc_window_length_s=agc_window_length_s,
146168
highpass_butter_order=highpass_butter_order,
147169
highpass_butter_wn=highpass_butter_wn,
170+
rms_values=rms_values,
148171
)
149172

150173

@@ -161,6 +184,8 @@ def __init__(
161184
order_f,
162185
order_r,
163186
dtype,
187+
epsilon,
188+
rms_values,
164189
):
165190
BasePreprocessorSegment.__init__(self, parent_recording_segment)
166191
self.parent_recording_segment = parent_recording_segment
@@ -185,6 +210,7 @@ def __init__(
185210
# get filter params
186211
self.sos_filter = sos_filter
187212
self.dtype = dtype
213+
self.epsilon_values_for_agc = epsilon * np.array(rms_values)
188214

189215
def get_traces(self, start_frame, end_frame, channel_indices):
190216
if channel_indices is None:
@@ -207,8 +233,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
207233
traces = traces.copy()
208234

209235
# apply AGC and keep the gains
236+
traces = traces.astype(np.float32)
210237
if self.window is not None:
211-
traces, agc_gains = agc(traces, window=self.window)
238+
traces, agc_gains = agc(traces, window=self.window, epsilons=self.epsilon_values_for_agc)
212239
else:
213240
agc_gains = None
214241
# pad the array with a mirrored version of itself and apply a cosine taper
@@ -255,36 +282,56 @@ def get_traces(self, start_frame, end_frame, channel_indices):
255282
# -----------------------------------------------------------------------------------------------
256283

257284

258-
def agc(traces, window, epsilon=1e-8):
285+
def agc(traces, window, epsilons):
259286
"""
260287
Automatic gain control
261288
w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8)
262289
such as w_agc * gain = w
263-
:param traces: seismic array (sample last dimension)
264-
:param window_length: window length (secs) (original default 0.5)
265-
:param si: sampling interval (secs) (original default 0.002)
266-
:param epsilon: whitening (useful mainly for synthetic data)
267-
:return: AGC data array, gain applied to data
290+
291+
Parameters
292+
----------
293+
traces : np.ndarray
294+
Input traces
295+
window : np.ndarray
296+
Window to use for AGC (1D array)
297+
epsilons : np.ndarray[float]
298+
Epsilon values for each channel to avoid division by zero
299+
300+
Returns
301+
-------
302+
agc_traces : np.ndarray
303+
AGC applied traces
304+
gain : np.ndarray
305+
Gain applied to the traces
268306
"""
269307
import scipy.signal
270308

271309
gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)
272310

273-
gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :]
274-
275311
dead_channels = np.sum(gain, axis=0) == 0
276312

277-
traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels]
313+
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilons, gain[:, ~dead_channels])
278314

279315
return traces, gain
280316

281317

282318
def fcn_extrap(x, f, bounds):
283319
"""
284320
Extrapolates a flat value before and after bounds
285-
x: array to be filtered
286-
f: function to be applied between bounds (cf. fcn_cosine below)
287-
bounds: 2 elements list or np.array
321+
322+
Parameters
323+
----------
324+
x : np.ndarray
325+
Input array
326+
f : function
327+
Function to be applied between bounds
328+
bounds : list or np.ndarray
329+
2 elements list or array defining the bounds
330+
331+
Returns
332+
-------
333+
y : np.ndarray
334+
Output array
288335
"""
289336
y = f(x)
290337
y[x < bounds[0]] = f(bounds[0])
@@ -298,8 +345,16 @@ def fcn_cosine(bounds):
298345
values <= bounds[0]: values
299346
values < bounds[0] < bounds[1] : cosine taper
300347
values < bounds[1]: bounds[1]
301-
:param bounds:
302-
:return: lambda function
348+
349+
Parameters
350+
----------
351+
bounds : list or np.ndarray
352+
2 elements list or array defining the bounds
353+
354+
Returns
355+
-------
356+
func : function
357+
Lambda function implementing the soft thresholding with cosine taper
303358
"""
304359

305360
def _cos(x):

src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from copy import deepcopy
55

6-
import spikeinterface as si
6+
import spikeinterface.core as si
77
import spikeinterface.preprocessing as spre
88
import spikeinterface.extractors as se
99
from spikeinterface.core import generate_recording
@@ -24,7 +24,7 @@
2424

2525

2626
@pytest.mark.skipif(
27-
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
27+
importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
2828
reason="Only local. Requires ibl-neuropixel install",
2929
)
3030
@pytest.mark.parametrize("lagc", [False, 1, 300])
@@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc):
5151
use DEBUG = true to visualise.
5252
5353
"""
54-
import spikeglx
55-
import neurodsp.voltage as voltage
54+
import ibldsp.voltage
55+
import neuropixel
5656

57-
options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
58-
print(options)
59-
60-
ibl_data, si_recording = get_ibl_si_data()
61-
62-
si_filtered, _ = run_si_highpass_filter(si_recording, **options)
57+
local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
58+
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
59+
si_recording = spre.astype(si_recording, "float")
60+
recording_ps = spre.phase_shift(si_recording)
61+
recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3)
62+
recording_hps = spre.highpass_spatial_filter(recording_hp)
63+
raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP
64+
si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP
6365

64-
ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options)
66+
destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1)
6567

6668
if DEBUG:
67-
fig, axs = plt.subplots(ncols=4)
68-
axs[0].imshow(si_recording.get_traces(return_in_uV=True))
69-
axs[0].set_title("SI Raw")
70-
axs[1].imshow(ibl_data.T)
71-
axs[1].set_title("IBL Raw")
72-
axs[2].imshow(si_filtered)
73-
axs[2].set_title("SI Filtered ")
74-
axs[3].imshow(ibl_filtered)
75-
axs[3].set_title("IBL Filtered")
69+
from viewephys.gui import viewephys
70+
71+
eqc = {}
72+
eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered")
73+
eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered")
7674

77-
assert np.allclose(
78-
si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0
79-
) # the differences are entired due to scaling on data load.
75+
np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0)
8076

8177

8278
@pytest.mark.parametrize("ntr_pad", [None, 0, 31])
@@ -140,24 +136,6 @@ def test_dtype_stability(dtype):
140136
# ----------------------------------------------------------------------------------------------------------------------
141137

142138

143-
def get_ibl_si_data():
144-
"""
145-
Set fixture to session to ensure origional data is not changed.
146-
"""
147-
import spikeglx
148-
149-
local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
150-
ibl_recording = spikeglx.Reader(
151-
local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
152-
)
153-
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel
154-
155-
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
156-
si_recording = spre.astype(si_recording, dtype="float32")
157-
158-
return ibl_data, si_recording
159-
160-
161139
def process_args_for_si(si_recording, lagc):
162140
""""""
163141
if isinstance(lagc, bool) and not lagc:
@@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs,
215193

216194

217195
def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs):
218-
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
196+
import ibldsp.voltage
219197

220-
ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T
198+
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
199+
ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T
221200

222201
return ibl_filtered
223202

0 commit comments

Comments
 (0)