Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import warnings
from pathlib import Path

Expand All @@ -7,14 +8,9 @@

from .base import BaseSegment
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import (
convert_bytes_to_str,
convert_seconds_to_str,
)
from .recording_tools import write_binary_recording


from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs
from .recording_tools import write_binary_recording


class BaseRecording(BaseRecordingSnippets):
Expand Down Expand Up @@ -921,11 +917,11 @@ def time_to_sample_index(self, time_s):
sample_index = time_s * self.sampling_frequency
else:
sample_index = (time_s - self.t_start) * self.sampling_frequency
sample_index = round(sample_index)
sample_index = np.round(sample_index).astype(int)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flagging this change! Sorry if this is not relevant enough to this PR, but I thought that while I was working on time logic it would be good to fix this last small quality of life thing (vectorizing time_to_sample_index -- note that the scalar case still behaves the same).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this follows @h-mayorquin 's comment

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting! This is nice, I think the only consideration is possible overflow for longer recordings, as int64 is capped but python int() is not capped. @h-mayorquin has been focussing on this more, but looking at a quick example below it should be fine.

int64 max value is 9,223,372,036,854,775,807. If we take a neuropixels recording, continuous for 2 months (not unfeasible these days) we have (30,000 * 60 * 60 * 24 * 60) = 155,520,000,000 samples (samples per second x seconds per minute x minutes per hour x hours per day x ~days in 2 months) (please check). But, maybe in 5 years people are sampling at 100 kHz and doing year long recordings 😆 we would have max index of (3.1536e+12). So I think it should be sufficient under all feasible uses, but something to consider.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting... wait, maybe I'm doing the math wrong, but don't we have:

# int64 max val     | samples/sec | sec/min | min/hr | hr/day
9223372036854775807 /     100_000 /      60 /     60 /     24
# => 1_067_519_911.6730065 days

which is quite a long time? (wolfram double check)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No I think we're good, I just meant that with that example we would have a max index of 3.1536e+12 out of a possible 9223372036854775807 for one year, a ratio of 2924712.08678, which is your 1_067_519_911.6730065/365. I think 1_067_519_911.6730065 days at 100 kHz is a much better way of putting it which really shows how sufficient this is!

else:
sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1

return int(sample_index)
return sample_index

def get_num_samples(self) -> int:
"""Returns the number of samples in this signal segment
Expand Down
53 changes: 38 additions & 15 deletions src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.preprocessing.filter import fix_dtype

from .motion_utils import ensure_time_bin_edges, ensure_time_bins


def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray:
"""
Expand Down Expand Up @@ -54,6 +56,7 @@ def interpolate_motion_on_traces(
segment_index=None,
channel_inds=None,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of interest, what is the use-case for allowing edges to be passed instead of centres, say vs. requiring centres only? I find this signature and the code necessary to handle either centres or edges a little confusing, but agree there are few options that allow this level of flexibility. I guess these options are typically not user-facing anyway? i.e. most users would be using the motion pipeline and can safely ignore this.

Also, a docstring addition in Parameters for interpolation_time_bin_edges_s would be great.

spatial_interpolation_method="kriging",
spatial_interpolation_kwargs={},
dtype=None,
Expand Down Expand Up @@ -119,17 +122,26 @@ def interpolate_motion_on_traces(
total_num_chans = channel_locations.shape[0]

# -- determine the blocks of frames that will land in the same interpolation time bin
time_bins = interpolation_time_bin_centers_s
if time_bins is None:
time_bins = motion.temporal_bins_s[segment_index]
bin_s = time_bins[1] - time_bins[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check my understanding, the searchsorted on the bin_edges is functionally equivalent to this approach? (but of course searchsorted is less verbose)

bins_start = time_bins[0] - 0.5 * bin_s
# nearest bin center for each frame?
bin_inds = (times - bins_start) // bin_s
bin_inds = bin_inds.astype(int)
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
bin_centers_s = motion.temporal_bin_edges_s[segment_index]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it correct here that the bin centers are assigned the temporal bin edges?

bin_edges_s = motion.temporal_bin_edges_s[segment_index]
else:
bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s)

# nearest interpolation bin:
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]

# hence the -1. doing it with "left" is not as nice -- we want t==b[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, I cannot believe that is not the default behaviour of "left"!

# to lead to i=1 (rounding down).
# time_bins are bin centers, but we want to snap to the nearest center.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this line

# idea is to get the left bin edges and bin the interp times.
# this is like subtracting bin_dt_s/2, but allows non-equally-spaced bins.
# it's fine to use the first bin center for the first left edge
bin_inds = np.searchsorted(bin_edges_s, times, side="right") - 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think now is a good opportunity to rename some of these variables? E.g. would it work to call bin_edges_s and bin_centers_s something like motion_bin_edges_s and motion_bin_centers_s, and bin_inds something like nearest_motion_bin_idx_per_frame? Maybe this then goes out of alignment with the other SI variable names so might not be worth it. But in this function I found it a little hard to track what is a motion time bin vs. binned AP data frames.


# the time bins may not cover the whole set of times in the recording,
# so we need to clip these indices to the valid range
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
n_bins = bin_edges_s.shape[0] - 1
np.clip(bin_inds, 0, n_bins - 1, out=bin_inds)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it is worth documenting this behaviour in the Parameters. I think it makes a lot of sense, but good to clarify in case a user thinks that frames outside of the bin range will not be interpolated at all?


# -- what are the possibilities here anyway?
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
Expand All @@ -138,7 +150,7 @@ def interpolate_motion_on_traces(
interp_times = np.empty(total_num_chans)
current_start_index = 0
for bin_ind in bins_here:
bin_time = time_bins[bin_ind]
bin_time = bin_centers_s[bin_ind]
interp_times.fill(bin_time)
channel_motions = motion.get_displacement_at_time_and_depth(
interp_times,
Expand Down Expand Up @@ -297,6 +309,7 @@ def __init__(
p=1,
num_closest=3,
interpolation_time_bin_centers_s=None,
interpolation_time_bin_edges_s=None,
interpolation_time_bin_size_s=None,
dtype=None,
**spatial_interpolation_kwargs,
Expand Down Expand Up @@ -363,9 +376,14 @@ def __init__(

# handle manual interpolation_time_bin_centers_s
# the case where interpolation_time_bin_size_s is set is handled per-segment below
if interpolation_time_bin_centers_s is None:
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
if interpolation_time_bin_size_s is None:
interpolation_time_bin_centers_s = motion.temporal_bins_s
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s
else:
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
)

for segment_index, parent_segment in enumerate(recording._recording_segments):
# finish the per-segment part of the time bin logic
Expand All @@ -375,8 +393,13 @@ def __init__(
t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]))
halfbin = interpolation_time_bin_size_s / 2.0
segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s)
segment_interpolation_time_bin_edges_s = np.arange(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(mostly for line 390) Is it possible for interpolation_time_bin_centers_s to be None at this point anymore? If centers and edges are both None, it will be motion.temporal_bins_s; if it is passed it will not be None; and if centers is None it will be filled in with ensure_time_bins?

t_start, t_end + halfbin, interpolation_time_bin_size_s
)
assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,)
else:
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index]
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index]

rec_segment = InterpolateMotionRecordingSegment(
parent_segment,
Expand All @@ -387,6 +410,7 @@ def __init__(
channel_inds,
segment_index,
segment_interpolation_time_bins_s,
segment_interpolation_time_bin_edges_s,
dtype=dtype_,
)
self.add_recording_segment(rec_segment)
Expand Down Expand Up @@ -420,6 +444,7 @@ def __init__(
channel_inds,
segment_index,
interpolation_time_bin_centers_s,
interpolation_time_bin_edges_s,
dtype="float32",
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
Expand All @@ -429,13 +454,11 @@ def __init__(
self.channel_inds = channel_inds
self.segment_index = segment_index
self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s
self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s
self.dtype = dtype
self.motion = motion

def get_traces(self, start_frame, end_frame, channel_indices):
if self.time_vector is not None:
raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.")

if start_frame is None:
start_frame = 0
if end_frame is None:
Expand All @@ -453,7 +476,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
channel_inds=self.channel_inds,
spatial_interpolation_method=self.spatial_interpolation_method,
spatial_interpolation_kwargs=self.spatial_interpolation_kwargs,
interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s,
interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s,
)

if channel_indices is not None:
Expand Down
24 changes: 23 additions & 1 deletion src/spikeinterface/sortingcomponents/motion/motion_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
import json
import warnings
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y"
self.direction = direction
self.dim = ["x", "y", "z"].index(direction)
self.check_properties()
self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s]

def check_properties(self):
assert all(d.ndim == 2 for d in self.displacement)
Expand Down Expand Up @@ -576,3 +577,24 @@ def make_3d_motion_histograms(
motion_histograms = np.log2(1 + motion_histograms)

return motion_histograms, temporal_bin_edges, spatial_bin_edges


def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
    """Return both the centers and the edges of a set of time bins.

    Time bins can be described either by their centers or by their edges.
    Whichever of the two is missing is computed from the other; an argument
    that was supplied is returned unchanged.

    Parameters
    ----------
    time_bin_centers_s : np.ndarray or None, default: None
        1D array of bin centers. If None, the centers are computed as the
        midpoints of consecutive entries of ``time_bin_edges_s``.
    time_bin_edges_s : np.ndarray or None, default: None
        1D array of bin edges with at least two entries. If None, the edges
        are computed from ``time_bin_centers_s``: interior edges are the
        midpoints between neighboring centers, and the first and last edges
        are set equal to the first and last centers.

    Returns
    -------
    time_bin_centers_s : np.ndarray
        The bin centers.
    time_bin_edges_s : np.ndarray
        The bin edges, with one more entry than ``time_bin_centers_s``.

    Raises
    ------
    ValueError
        If both ``time_bin_centers_s`` and ``time_bin_edges_s`` are None.
    """
    if time_bin_centers_s is None and time_bin_edges_s is None:
        raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

    if time_bin_centers_s is None:
        assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
        time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

    if time_bin_edges_s is None:
        # Interior edges are midpoints, so the buffer must be floating point:
        # keep a floating dtype as-is, but promote integer centers to float64
        # so the 0.5 * (...) values are not truncated on assignment.
        edges_dtype = time_bin_centers_s.dtype
        if not np.issubdtype(edges_dtype, np.floating):
            edges_dtype = np.float64

        time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=edges_dtype)
        time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
        # As soon as there are two centers there is an interior edge to fill;
        # skipping it would leave uninitialized np.empty memory in the result.
        if time_bin_centers_s.size > 1:
            time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])

    return time_bin_centers_s, time_bin_edges_s


def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
    """Return the time bin edges for bins given by centers and/or edges.

    Thin wrapper around ``ensure_time_bins``: both arguments are forwarded
    unchanged and only the second element of the returned
    ``(centers, edges)`` pair is kept.
    """
    _, edges_s = ensure_time_bins(time_bin_centers_s, time_bin_edges_s)
    return edges_s
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from pathlib import Path
import warnings

import numpy as np
import pytest
import spikeinterface.core as sc
from spikeinterface import download_dataset
from spikeinterface.sortingcomponents.motion import Motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import (
InterpolateMotionRecording,
correct_motion_on_peaks,
interpolate_motion,
interpolate_motion_on_traces,
)
from spikeinterface.sortingcomponents.motion import Motion
from spikeinterface.sortingcomponents.tests.common import make_dataset


Expand Down Expand Up @@ -67,18 +65,20 @@ def test_interpolate_motion_on_traces():
times = rec.get_times()[0:30000]

for method in ("kriging", "idw", "nearest"):
traces_corrected = interpolate_motion_on_traces(
traces,
times,
channel_locations,
motion,
channel_inds=None,
spatial_interpolation_method=method,
# spatial_interpolation_kwargs={},
spatial_interpolation_kwargs={"force_extrapolate": True},
)
assert traces.shape == traces_corrected.shape
assert traces.dtype == traces_corrected.dtype
for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)):
traces_corrected = interpolate_motion_on_traces(
traces,
times,
channel_locations,
motion,
channel_inds=None,
spatial_interpolation_method=method,
interpolation_time_bin_centers_s=interpolation_time_bin_centers_s,
# spatial_interpolation_kwargs={},
spatial_interpolation_kwargs={"force_extrapolate": True},
)
assert traces.shape == traces_corrected.shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside the scope of this PR (as tests were like this anyways), but it would be nice to extend these tests to explicitly check the values are correct. Testing the shape will not pick up any regressions that mess up the actual computation but leave the shape intact. Unfortunately these are the least likely to be picked up as any erroneous shapes will probably crash at runtime anyway.

I am wondering if traces is a simple 3x3 array (3 channels, 3 timepoints) it would be relatively easy to compute manually the expected results of kriging, idw and NN interpolation and check against the output of this function? Of course, this is outside the scope of the PR so feel free to ignore and I can write an issue!

assert traces.dtype == traces_corrected.dtype


def test_interpolation_simple():
Expand Down Expand Up @@ -115,6 +115,66 @@ def test_interpolation_simple():
assert np.all(traces_corrected[:, 2:] == 0)


def test_cross_band_interpolation():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a really nice test, super useful also for conceptualising the cross-band interpolation

"""Simple version of using LFP to interpolate AP data

This also tests the time vector implementation in interpolation.
The idea is to have two recordings which are all 0s with a 1 that
moves from one channel to another after 3s. They're at different
sampling frequencies. motion estimation in one sampling frequency
applied to the other should still lead to perfect correction.
"""
from spikeinterface.sortingcomponents.motion import estimate_motion

# sampling freqs and timing for AP and LFP recordings
fs_lfp = 50.0
fs_ap = 300.0
t_start = 10.0
total_duration = 5.0
nt_lfp = int(fs_lfp * total_duration)
nt_ap = int(fs_ap * total_duration)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is nt num_timepoints? could it be expanded?

t_switch = 3

# because interpolation uses bin centers logic, there will be a half
# bin offset at the change point in the AP recording.
halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp))

# channel geometry
nc = 10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this be num_channels or num_chans?

geom = np.c_[np.zeros(nc), np.arange(nc)]

# make an LFP recording which drifts a bit
traces_lfp = np.zeros((nt_lfp, nc))
traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0
traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0
rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp)
rec_lfp.set_dummy_probe_from_locations(geom)

# same for AP
traces_ap = np.zeros((nt_ap, nc))
traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0
traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0
rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap)
rec_ap.set_dummy_probe_from_locations(geom)

# set times for both, and silence the warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp)
rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap)

# estimate motion
motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True)

# nearest to keep it simple
rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2)
traces_corrected = rec_corrected.get_traces()
target = np.zeros((nt_ap, nc - 2))
target[:, 4] = 1
ii, jj = np.nonzero(traces_corrected)
assert np.array_equal(traces_corrected, target)


def test_InterpolateMotionRecording():
rec, sorting = make_dataset()
motion = make_fake_motion(rec)
Expand Down Expand Up @@ -147,6 +207,7 @@ def test_InterpolateMotionRecording():

if __name__ == "__main__":
# test_correct_motion_on_peaks()
# test_interpolate_motion_on_traces()
test_interpolation_simple()
test_InterpolateMotionRecording()
test_interpolate_motion_on_traces()
# test_interpolation_simple()
# test_InterpolateMotionRecording()
test_cross_band_interpolation()