Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class BaseRecording(BaseRecordingSnippets):
"noise_level_std_scaled",
"noise_level_mad_raw",
"noise_level_mad_scaled",
"noise_level_rms_raw",
"noise_level_rms_scaled",
]

def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
Expand Down
16 changes: 3 additions & 13 deletions src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import probeinterface

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts_from_probe
from spikeinterface.core.core_tools import define_function_from_class


Expand Down Expand Up @@ -44,22 +44,13 @@ class CompressedBinaryIblExtractor(BaseRecording):

installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n"

def __init__(
self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None
):
def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None):
from neo.rawio.spikeglxrawio import read_meta_file

try:
import mtscomp
except ImportError:
raise ImportError(self.installation_mesg)
if cbin_file is not None:
warnings.warn(
"The `cbin_file` argument is deprecated and will be removed in version 0.104.0, please use `cbin_file_path` instead",
DeprecationWarning,
stacklevel=2,
)
cbin_file_path = cbin_file
if cbin_file_path is None:
folder_path = Path(folder_path)
# check bands
Expand Down Expand Up @@ -124,8 +115,7 @@ def __init__(
num_channels_per_adc = 16
else: # NP1.0
num_channels_per_adc = 12

sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc)
sample_shifts = get_neuropixels_sample_shifts_from_probe(probe, num_channels_per_adc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed this, which now needs to be probe directly!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good with me, I also fixed the tests so the CI passes.

self.set_property("inter_sample_shift", sample_shifts)

self._kwargs = {
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor):
def __init__(
self,
recording,
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f
def causal_filter(
recording,
direction="forward",
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down
93 changes: 74 additions & 19 deletions src/spikeinterface/preprocessing/highpass_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import numpy as np

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .filter import fix_dtype
from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording
from spikeinterface.preprocessing.filter import fix_dtype
from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin, get_noise_levels
from spikeinterface.core.core_tools import define_function_handling_dict_from_class


Expand Down Expand Up @@ -48,8 +48,17 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
Order of spatial butterworth filter
highpass_butter_wn : float, default: 0.01
Critical frequency (with respect to Nyquist) of spatial butterworth filter
epsilon : float, default: 0.003
Value multiplied to RMS values to avoid division by zero during AGC.
random_slice_kwargs : dict | None, default: None
If not None, dictionary of arguments to be passed to `get_noise_levels` when computing
noise levels.
dtype : dtype, default: None
The dtype of the output traces. If None, the dtype is the same as the input traces
rms_values : np.ndarray | None, default: None
If not None, array of RMS values for each channel to be used during AGC. If None, RMS values are computed
from the recording. This is used to cache pre-computed RMS values, which are only computed once at
initialization.

Returns
-------
Expand All @@ -66,15 +75,18 @@ class HighpassSpatialFilterRecording(BasePreprocessor):

def __init__(
self,
recording,
recording: BaseRecording,
n_channel_pad=60,
n_channel_taper=0,
direction="y",
apply_agc=True,
agc_window_length_s=0.1,
highpass_butter_order=3,
highpass_butter_wn=0.01,
epsilon=0.003,
random_slice_kwargs=None,
dtype=None,
rms_values=None,
):
BasePreprocessor.__init__(self, recording)

Expand Down Expand Up @@ -115,6 +127,14 @@ def __init__(
if not apply_agc:
agc_window_length_s = None

# Compute or retrieve RMS values
if rms_values is None:
if "noise_level_rms_raw" in recording.get_property_keys():
rms_values = recording.get_property("noise_level_rms_raw")
else:
random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs
rms_values = get_noise_levels(recording, method="rms", return_scaled=False, **random_slice_kwargs)

# Pre-compute spatial filtering parameters
butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn)
sos_filter = scipy.signal.butter(**butter_kwargs, output="sos")
Expand All @@ -133,6 +153,8 @@ def __init__(
order_f,
order_r,
dtype=dtype,
epsilon=epsilon,
rms_values=rms_values,
)
self.add_recording_segment(rec_segment)

Expand All @@ -145,6 +167,7 @@ def __init__(
agc_window_length_s=agc_window_length_s,
highpass_butter_order=highpass_butter_order,
highpass_butter_wn=highpass_butter_wn,
rms_values=rms_values,
)


Expand All @@ -161,6 +184,8 @@ def __init__(
order_f,
order_r,
dtype,
epsilon,
rms_values,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
Expand All @@ -185,6 +210,7 @@ def __init__(
# get filter params
self.sos_filter = sos_filter
self.dtype = dtype
self.epsilon_values_for_agc = epsilon * np.array(rms_values)

def get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
Expand All @@ -207,8 +233,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces.copy()

# apply AGC and keep the gains
traces = traces.astype(np.float32)
if self.window is not None:
traces, agc_gains = agc(traces, window=self.window)
traces, agc_gains = agc(traces, window=self.window, epsilons=self.epsilon_values_for_agc)
else:
agc_gains = None
# pad the array with a mirrored version of itself and apply a cosine taper
Expand Down Expand Up @@ -255,36 +282,56 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# -----------------------------------------------------------------------------------------------


def agc(traces, window, epsilon=1e-8):
def agc(traces, window, epsilons):
"""
Automatic gain control
w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8)
such as w_agc * gain = w
:param traces: seismic array (sample last dimension)
:param window_length: window length (secs) (original default 0.5)
:param si: sampling interval (secs) (original default 0.002)
:param epsilon: whitening (useful mainly for synthetic data)
:return: AGC data array, gain applied to data

Parameters
----------
traces : np.ndarray
Input traces
window : np.ndarray
Window to use for AGC (1D array)
epsilons : np.ndarray[float]
Epsilon values for each channel to avoid division by zero

Returns
-------
agc_traces : np.ndarray
AGC applied traces
gain : np.ndarray
Gain applied to the traces
"""
import scipy.signal

gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)

gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :]

dead_channels = np.sum(gain, axis=0) == 0

traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels]
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilons, gain[:, ~dead_channels])

return traces, gain


def fcn_extrap(x, f, bounds):
"""
Extrapolates a flat value before and after bounds
x: array to be filtered
f: function to be applied between bounds (cf. fcn_cosine below)
bounds: 2 elements list or np.array

Parameters
----------
x : np.ndarray
Input array
f : function
Function to be applied between bounds
bounds : list or np.ndarray
2 elements list or array defining the bounds

Returns
-------
y : np.ndarray
Output array
"""
y = f(x)
y[x < bounds[0]] = f(bounds[0])
Expand All @@ -298,8 +345,16 @@ def fcn_cosine(bounds):
values <= bounds[0]: values
values < bounds[0] < bounds[1] : cosine taper
values < bounds[1]: bounds[1]
:param bounds:
:return: lambda function

Parameters
----------
bounds : list or np.ndarray
2 elements list or array defining the bounds

Returns
-------
func : function
Lambda function implementing the soft thresholding with cosine taper
"""

def _cos(x):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from copy import deepcopy

import spikeinterface as si
import spikeinterface.core as si
import spikeinterface.preprocessing as spre
import spikeinterface.extractors as se
from spikeinterface.core import generate_recording
Expand All @@ -24,7 +24,7 @@


@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("lagc", [False, 1, 300])
Expand All @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc):
use DEBUG = true to visualise.

"""
import spikeglx
import neurodsp.voltage as voltage
import ibldsp.voltage
import neuropixel

options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
print(options)

ibl_data, si_recording = get_ibl_si_data()

si_filtered, _ = run_si_highpass_filter(si_recording, **options)
local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, "float")
recording_ps = spre.phase_shift(si_recording)
recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3)
recording_hps = spre.highpass_spatial_filter(recording_hp)
raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP
si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP

ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options)
destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1)

if DEBUG:
fig, axs = plt.subplots(ncols=4)
axs[0].imshow(si_recording.get_traces(return_in_uV=True))
axs[0].set_title("SI Raw")
axs[1].imshow(ibl_data.T)
axs[1].set_title("IBL Raw")
axs[2].imshow(si_filtered)
axs[2].set_title("SI Filtered ")
axs[3].imshow(ibl_filtered)
axs[3].set_title("IBL Filtered")
from viewephys.gui import viewephys

eqc = {}
eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered")
eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered")

assert np.allclose(
si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0
) # the differences are entired due to scaling on data load.
np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0)


@pytest.mark.parametrize("ntr_pad", [None, 0, 31])
Expand Down Expand Up @@ -140,24 +136,6 @@ def test_dtype_stability(dtype):
# ----------------------------------------------------------------------------------------------------------------------


def get_ibl_si_data():
"""
Set fixture to session to ensure origional data is not changed.
"""
import spikeglx

local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
ibl_recording = spikeglx.Reader(
local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
)
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel

si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, dtype="float32")

return ibl_data, si_recording


def process_args_for_si(si_recording, lagc):
""""""
if isinstance(lagc, bool) and not lagc:
Expand Down Expand Up @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs,


def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs):
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
import ibldsp.voltage

ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T

return ibl_filtered

Expand Down