Fix a cross-band interpolation bug, and allow time_vector in interpolate_motion #3517
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,6 +6,8 @@ | |||||
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment | ||||||
from spikeinterface.preprocessing.filter import fix_dtype | ||||||
|
||||||
from .motion_utils import ensure_time_bin_edges, ensure_time_bins | ||||||
|
||||||
|
||||||
def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: | ||||||
""" | ||||||
|
@@ -54,6 +56,7 @@ def interpolate_motion_on_traces( | |||||
segment_index=None, | ||||||
channel_inds=None, | ||||||
interpolation_time_bin_centers_s=None, | ||||||
interpolation_time_bin_edges_s=None, | ||||||
Out of interest, what is the use-case for allowing edges to be passed instead of centres, say vs. requiring centres only? I find this signature and the code necessary to handle either centres or edges a little confusing, but agree there are few options that allow this level of flexibility. I guess these options are typically not user-facing anyway? i.e. most users would be using the motion pipeline and can safely ignore this. Also, a docstring addition in Parameters for |
||||||
spatial_interpolation_method="kriging", | ||||||
spatial_interpolation_kwargs={}, | ||||||
dtype=None, | ||||||
|
@@ -119,17 +122,26 @@ def interpolate_motion_on_traces( | |||||
total_num_chans = channel_locations.shape[0] | ||||||
|
||||||
# -- determine the blocks of frames that will land in the same interpolation time bin | ||||||
time_bins = interpolation_time_bin_centers_s | ||||||
if time_bins is None: | ||||||
time_bins = motion.temporal_bins_s[segment_index] | ||||||
bin_s = time_bins[1] - time_bins[0] | ||||||
Just to check my understanding, the |
||||||
bins_start = time_bins[0] - 0.5 * bin_s | ||||||
# nearest bin center for each frame? | ||||||
bin_inds = (times - bins_start) // bin_s | ||||||
bin_inds = bin_inds.astype(int) | ||||||
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: | ||||||
bin_centers_s = motion.temporal_bins_s[segment_index] | ||||||
|
||||||
bin_edges_s = motion.temporal_bin_edges_s[segment_index] | ||||||
else: | ||||||
bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) | ||||||
|
||||||
# nearest interpolation bin: | ||||||
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] | ||||||
|
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] | |
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] |
Wow, I cannot believe that is not the default behaviour of "left"!
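To make the `side="right"` convention in the code comment above concrete, here is a minimal sketch with made-up bin edges and frame times (the array values are illustrative assumptions, not taken from the PR):

```python
import numpy as np

# Hypothetical bin edges (seconds) and frame times, chosen for illustration.
edges = np.array([0.0, 1.0, 2.0, 3.0])
times = np.array([0.0, 0.5, 1.0, 2.9])

# side="right": returns i such that edges[i-1] <= t < edges[i]
right = np.searchsorted(edges, times, side="right")

# side="left" (the NumPy default): returns i such that edges[i-1] < t <= edges[i]
left = np.searchsorted(edges, times, side="left")

# Subtracting 1 from the side="right" result gives the index of the
# half-open bin [edges[k], edges[k+1]) that contains each time.
bin_inds = right - 1

print(right)     # [1 1 2 3]
print(left)      # [0 1 1 3]
print(bin_inds)  # [0 0 1 2]
```

The two conventions differ only for times that fall exactly on an edge (here `0.0` and `1.0`), which is why `side="right"` is the one that gives left-closed bins.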
I don't quite understand this line
Do you think now is a good opportunity to rename some of these variables? E.g. would it work to call `bin_edges_s` and `bin_centers_s` maybe `motion_bin_edges_s` and `motion_bin_centers_s`, and maybe `bin_ids` something like `nearest_motion_bin_idx_per_frame`? Maybe this then goes out of alignment with the other SI variable names, so it might not be worth it. But in this function I found it a little hard to track what is a motion time bin vs. binned AP data frames.
I wonder if it is worth documenting this behaviour in the `Parameters`. I think it makes a lot of sense, but it would be good to clarify, in case a user thinks that frames outside of the bin range will not be interpolated at all.
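The code line this comment is attached to is not shown on this page. Assuming the behaviour in question is that frames falling outside the range of the interpolation time bins are assigned to the nearest (first or last) bin rather than skipped, a minimal sketch could look like this. The helper name `nearest_bin_per_frame` and the example values are hypothetical:

```python
import numpy as np


def nearest_bin_per_frame(times_s, bin_edges_s):
    """Hypothetical helper: index of the time bin used for each frame.

    Frames before the first edge use bin 0, and frames at or after the
    last edge use the last bin (an assumption about the intended behaviour).
    """
    bin_inds = np.searchsorted(bin_edges_s, times_s, side="right") - 1
    n_bins = bin_edges_s.shape[0] - 1
    # Clip out-of-range indices (-1 or n_bins) to the valid range.
    return np.clip(bin_inds, 0, n_bins - 1)


edges = np.array([1.0, 2.0, 3.0])            # two bins: [1, 2) and [2, 3)
times = np.array([0.2, 1.5, 2.5, 3.0, 4.7])  # first and last two are out of range
inds = nearest_bin_per_frame(times, edges)
print(inds)  # [0 0 1 1 1]
```

Under this assumption every frame is still interpolated, and out-of-range frames simply reuse the displacement of the closest bin.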
(mostly for line 390) Is it possible for `interpolation_time_bin_centers_s` to be `None` at this point anymore? If centers and edges are both `None`, it will be `motion.temporal_bins_s`; if it is passed, it will not be `None`; and if centers is `None`, it will be filled in with `ensure_time_bins`?
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import warnings | ||
import json | ||
import warnings | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
@@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" | |
self.direction = direction | ||
self.dim = ["x", "y", "z"].index(direction) | ||
self.check_properties() | ||
self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s] | ||
|
||
def check_properties(self): | ||
assert all(d.ndim == 2 for d in self.displacement) | ||
|
@@ -576,3 +577,24 @@ def make_3d_motion_histograms( | |
motion_histograms = np.log2(1 + motion_histograms) | ||
|
||
return motion_histograms, temporal_bin_edges, spatial_bin_edges | ||
|
||
|
||
def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): | ||
A docstring would be useful here, just to explain a) the case in which this is used and b) a brief overview of what it is doing. If I understand correctly, we need both bin centres and bin edges. Given some bin centres, we compute the edges, or vice versa given some bin edges we compute the centres? |
||
if time_bin_centers_s is None and time_bin_edges_s is None: | ||
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") | ||
|
||
if time_bin_centers_s is None: | ||
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2 | ||
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) | ||
|
||
if time_bin_edges_s is None: | ||
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype) | ||
Could this always be float? As we are multiplying by |
||
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]] | ||
if time_bin_centers_s.size > 2: | ||
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1]) | ||
|
||
return time_bin_centers_s, time_bin_edges_s | ||
|
||
|
||
def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): | ||
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] |
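As a rough illustration of the two review suggestions on `ensure_time_bins` (a docstring, and always working in float), here is a standalone sketch. It is not the PR's implementation: the name `ensure_time_bins_sketch` is hypothetical, and it keeps the diff's convention that the outermost edges equal the first and last centers.

```python
import numpy as np


def ensure_time_bins_sketch(time_bin_centers_s=None, time_bin_edges_s=None):
    """Return (centers, edges) given at least one of the two.

    Edges -> centers: midpoints of consecutive edges.
    Centers -> edges: midpoints of consecutive centers for interior edges;
    the first and last edges are set to the first and last centers.
    Both outputs are float arrays.
    """
    if time_bin_centers_s is None and time_bin_edges_s is None:
        raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

    if time_bin_centers_s is None:
        edges = np.asarray(time_bin_edges_s, dtype=float)
        if edges.ndim != 1 or edges.size < 2:
            raise ValueError("time_bin_edges_s must be 1D with at least 2 entries.")
        centers = 0.5 * (edges[1:] + edges[:-1])
        return centers, edges

    centers = np.asarray(time_bin_centers_s, dtype=float)
    if centers.ndim != 1 or centers.size < 1:
        raise ValueError("time_bin_centers_s must be 1D and non-empty.")

    if time_bin_edges_s is None:
        edges = np.empty(centers.size + 1, dtype=float)
        edges[[0, -1]] = centers[[0, -1]]
        # Interior edges; an empty assignment when there is a single center.
        edges[1:-1] = 0.5 * (centers[1:] + centers[:-1])
    else:
        edges = np.asarray(time_bin_edges_s, dtype=float)
    return centers, edges


centers, edges = ensure_time_bins_sketch(time_bin_centers_s=np.array([1, 2, 4]))
print(edges)  # [1.  1.5 3.  4. ]
```

In this sketch the interior-edge assignment is unconditional, so the case of exactly two centers also gets its single interior edge filled in, and integer inputs are promoted to float before the `0.5 *` step.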
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,14 @@ | ||
from pathlib import Path | ||
import warnings | ||
|
||
import numpy as np | ||
import pytest | ||
import spikeinterface.core as sc | ||
from spikeinterface import download_dataset | ||
from spikeinterface.sortingcomponents.motion import Motion | ||
from spikeinterface.sortingcomponents.motion.motion_interpolation import ( | ||
InterpolateMotionRecording, | ||
correct_motion_on_peaks, | ||
interpolate_motion, | ||
interpolate_motion_on_traces, | ||
) | ||
from spikeinterface.sortingcomponents.motion import Motion | ||
from spikeinterface.sortingcomponents.tests.common import make_dataset | ||
|
||
|
||
|
@@ -67,18 +65,20 @@ def test_interpolate_motion_on_traces(): | |
times = rec.get_times()[0:30000] | ||
|
||
for method in ("kriging", "idw", "nearest"): | ||
traces_corrected = interpolate_motion_on_traces( | ||
traces, | ||
times, | ||
channel_locations, | ||
motion, | ||
channel_inds=None, | ||
spatial_interpolation_method=method, | ||
# spatial_interpolation_kwargs={}, | ||
spatial_interpolation_kwargs={"force_extrapolate": True}, | ||
) | ||
assert traces.shape == traces_corrected.shape | ||
assert traces.dtype == traces_corrected.dtype | ||
for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)): | ||
traces_corrected = interpolate_motion_on_traces( | ||
traces, | ||
times, | ||
channel_locations, | ||
motion, | ||
channel_inds=None, | ||
spatial_interpolation_method=method, | ||
interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, | ||
# spatial_interpolation_kwargs={}, | ||
spatial_interpolation_kwargs={"force_extrapolate": True}, | ||
) | ||
assert traces.shape == traces_corrected.shape | ||
Outside the scope of this PR (as tests were like this anyways), but it would be nice to extend these tests to explicitly check the values are correct. Testing the shape will not pick up any regressions that mess up the actual computation but leave the shape intact. Unfortunately these are the least likely to be picked up, as any erroneous shapes will probably crash at runtime anyway. I am wondering, if traces is a simple 3x3 array (3 channels, 3 timepoints), whether it would be relatively easy to compute manually the expected results of kriging, idw and NN interpolation and check against the output of this function? Of course, this is outside the scope of the PR so feel free to ignore and I can write an issue! |
||
assert traces.dtype == traces_corrected.dtype | ||
|
||
|
||
def test_interpolation_simple(): | ||
|
@@ -115,6 +115,66 @@ def test_interpolation_simple(): | |
assert np.all(traces_corrected[:, 2:] == 0) | ||
|
||
|
||
def test_cross_band_interpolation(): | ||
this is a really nice test, super useful also for conceptualising the cross-band interpolation |
||
"""Simple version of using LFP to interpolate AP data | ||
|
||
This also tests the time vector implementation in interpolation. | ||
The idea is to have two recordings which are all 0s with a 1 that | ||
moves from one channel to another after 3s. They're at different | ||
sampling frequencies. motion estimation in one sampling frequency | ||
applied to the other should still lead to perfect correction. | ||
""" | ||
from spikeinterface.sortingcomponents.motion import estimate_motion | ||
|
||
# sampling freqs and timing for AP and LFP recordings | ||
fs_lfp = 50.0 | ||
fs_ap = 300.0 | ||
t_start = 10.0 | ||
total_duration = 5.0 | ||
nt_lfp = int(fs_lfp * total_duration) | ||
nt_ap = int(fs_ap * total_duration) | ||
|
||
t_switch = 3 | ||
|
||
# because interpolation uses bin centers logic, there will be a half | ||
# bin offset at the change point in the AP recording. | ||
halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) | ||
|
||
# channel geometry | ||
nc = 10 | ||
|
||
geom = np.c_[np.zeros(nc), np.arange(nc)] | ||
|
||
# make an LFP recording which drifts a bit | ||
traces_lfp = np.zeros((nt_lfp, nc)) | ||
traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 | ||
traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 | ||
rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) | ||
rec_lfp.set_dummy_probe_from_locations(geom) | ||
|
||
# same for AP | ||
traces_ap = np.zeros((nt_ap, nc)) | ||
traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 | ||
traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 | ||
rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) | ||
rec_ap.set_dummy_probe_from_locations(geom) | ||
|
||
# set times for both, and silence the warning | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore", category=UserWarning) | ||
rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp) | ||
rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap) | ||
|
||
# estimate motion | ||
motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) | ||
|
||
# nearest to keep it simple | ||
rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) | ||
traces_corrected = rec_corrected.get_traces() | ||
target = np.zeros((nt_ap, nc - 2)) | ||
target[:, 4] = 1 | ||
ii, jj = np.nonzero(traces_corrected) | ||
assert np.array_equal(traces_corrected, target) | ||
|
||
|
||
def test_InterpolateMotionRecording(): | ||
rec, sorting = make_dataset() | ||
motion = make_fake_motion(rec) | ||
|
@@ -147,6 +207,7 @@ def test_InterpolateMotionRecording(): | |
|
||
if __name__ == "__main__": | ||
# test_correct_motion_on_peaks() | ||
# test_interpolate_motion_on_traces() | ||
test_interpolation_simple() | ||
test_InterpolateMotionRecording() | ||
test_interpolate_motion_on_traces() | ||
# test_interpolation_simple() | ||
# test_InterpolateMotionRecording() | ||
test_cross_band_interpolation() | ||
cwindolf marked this conversation as resolved.
Flagging this change! Sorry if this is not relevant enough to this PR, but I thought that while I was working on time logic it would be good to fix this last small quality of life thing (vectorizing time_to_sample_index -- note that the scalar case still behaves the same).
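The changed code for `time_to_sample_index` is not shown on this page. As a hedged sketch of what a vectorized version could look like (the function name, signature, and the use of `np.round` and `np.searchsorted` here are assumptions for illustration, not the PR's actual code):

```python
import numpy as np


def time_to_sample_index_sketch(time_s, sampling_frequency, t_start=0.0, time_vector=None):
    """Hypothetical standalone version accepting a scalar or an array of times."""
    if time_vector is None:
        # Regularly sampled case: scale and round, then cast to int64.
        sample_index = (np.asarray(time_s) - t_start) * sampling_frequency
        return np.round(sample_index).astype("int64")
    # Explicit time vector: index of the last sample at or before each time.
    return np.searchsorted(time_vector, time_s, side="right") - 1


# The scalar case still gives a single index.
print(time_to_sample_index_sketch(1.0, 30_000.0))  # 30000

# The array case gives one index per time.
print(time_to_sample_index_sketch(np.array([0.0, 0.5]), 10.0))  # [0 5]

# With a time vector.
tv = np.array([0.0, 0.5, 1.0, 1.5])
print(time_to_sample_index_sketch(np.array([0.0, 0.6, 1.5]), 2.0, time_vector=tv))  # [0 1 3]
```

Because the result is cast to `int64` rather than built with Python's `int()`, the index range is bounded, which is the point picked up in the overflow discussion that follows.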
I guess this follows @h-mayorquin 's comment
Interesting! This is nice. I think the only consideration is possible overflow for longer recordings, as `int64` is capped but Python `int()` is not. @h-mayorquin has been focussing on this more, but looking at a quick example below it should be fine. The `int64` max value is `9,223,372,036,854,775,807`. If we take a Neuropixels recording, continuous for 2 months (not unfeasible these days), we have `(30,000 * 60 * 60 * 24 * 60) = 155520000000` samples (samples per second x seconds per minute x minutes per hour x hours per day x ~days in 2 months) (please check). But maybe in 5 years people are sampling at 100 kHz and doing year-long recordings 😆, in which case we would have a max index of `3.1536e+12`. So I think this should be sufficient under all feasible uses, but it is something to consider.
Interesting... wait, maybe I'm doing the math wrong, but don't we have:
which is quite a long time? (wolfram double check)
No, I think we're good. I just meant that with that example we would have a max index of `3.1536e+12` out of a possible `9223372036854775807` for one year; the ratio is `2924712.08678`, which is your `1_067_519_911.6730065/365`. I think `1_067_519_911.6730065 days` at 100 kHz is a much better way of putting it, which really shows how sufficient this is!
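For anyone who wants to re-check the arithmetic in this thread, the figures can be reproduced with a few lines. The variable names are only for illustration, and the sample counts are the products of the factors quoted in the comments above:

```python
import numpy as np

int64_max = np.iinfo(np.int64).max  # 9223372036854775807
seconds_per_day = 60 * 60 * 24

# Neuropixels-style recording: 30 kHz for ~2 months (60 days).
samples_two_months_30khz = 30_000 * seconds_per_day * 60

# Hypothetical future recording: 100 kHz for one year (365 days).
fs_hz = 100_000
samples_one_year_100khz = fs_hz * seconds_per_day * 365

# How long can an int64 sample index last at 100 kHz?
max_days_100khz = int64_max / (fs_hz * seconds_per_day)
max_years_100khz = int64_max / samples_one_year_100khz

print(samples_two_months_30khz)  # 155520000000
print(samples_one_year_100khz)   # 3153600000000
print(max_days_100khz)           # roughly 1.0675e9 days
print(max_years_100khz)          # roughly 2.92e6 years
```

Both sample counts are many orders of magnitude below the `int64` maximum, consistent with the conclusion of the thread.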