Skip to content

Commit 123e862

Browse files
authored
Implement Add/SubtractRecording classes and +- operators (#4238)
1 parent fca62da commit 123e862

File tree

3 files changed

+142
-0
lines changed

3 files changed

+142
-0
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ def _repr_html_(self, display_name=True):
161161
html_repr = html_header + html_segments + html_channel_ids + html_extra
162162
return html_repr
163163

164+
def __add__(self, other):
165+
from .operatorrecordings import AddRecordings
166+
167+
return AddRecordings(self, other)
168+
169+
def __sub__(self, other):
170+
from .operatorrecordings import SubtractRecordings
171+
172+
return SubtractRecordings(self, other)
173+
164174
def get_num_segments(self) -> int:
165175
"""
166176
Returns the number of segments.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from spikeinterface.core import BaseRecording, BaseRecordingSegment
2+
from spikeinterface.core.recording_tools import get_rec_attributes, do_recording_attributes_match
3+
4+
5+
class BaseOperatorRecording(BaseRecording):
6+
"""Base class for operator recordings."""
7+
8+
def __init__(self, recording1, recording2, operator: str):
9+
assert operator in ["add", "subtract"], "Operator must be 'add' or 'subtract'"
10+
assert all(
11+
isinstance(rec, BaseRecording) for rec in [recording1, recording2]
12+
), "'recordings' must be a list of RecordingExtractor"
13+
14+
rec_attrs2 = get_rec_attributes(recording2)
15+
assert do_recording_attributes_match(
16+
recording1, rec_attrs2
17+
), "Both recordings must have the same sampling frequency and channel ids"
18+
assert self.are_times_kwargs_compatible(
19+
recording1, recording2
20+
), "Both recordings must have the same time parameters"
21+
22+
channel_ids = recording1.channel_ids
23+
sampling_frequency = recording1.sampling_frequency
24+
dtype = recording1.get_dtype()
25+
26+
BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)
27+
28+
for segment1, segment2 in zip(recording1._recording_segments, recording2._recording_segments):
29+
add_segment = OperatorRecordingSegment(segment1, segment2, operator)
30+
self.add_recording_segment(add_segment)
31+
32+
self._kwargs = dict(recording1=recording1, recording2=recording2, operator=operator)
33+
34+
def are_times_kwargs_compatible(self, recording1, recording2) -> bool:
35+
import numpy as np
36+
37+
for segment_index in range(recording1.get_num_segments()):
38+
time_kwargs1 = recording1._recording_segments[segment_index].get_times_kwargs()
39+
time_kwargs2 = recording2._recording_segments[segment_index].get_times_kwargs()
40+
for key in time_kwargs1.keys():
41+
val1 = time_kwargs1[key]
42+
val2 = time_kwargs2[key]
43+
if (val1 is None and val2 is not None) or (val1 is not None and val2 is None):
44+
return False
45+
if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray):
46+
if not np.array_equal(val1, val2):
47+
return False
48+
else:
49+
if val1 != val2:
50+
return False
51+
return True
52+
53+
54+
class OperatorRecordingSegment(BaseRecordingSegment):
55+
def __init__(self, segment1, segment2, operator: str):
56+
BaseRecordingSegment.__init__(self, **segment1.get_times_kwargs())
57+
self.segment1 = segment1
58+
self.segment2 = segment2
59+
self.operator = operator
60+
61+
def get_num_samples(self):
62+
return self.segment1.get_num_samples()
63+
64+
def get_traces(self, start_frame, end_frame, channel_indices):
65+
traces1 = self.segment1.get_traces(start_frame, end_frame, channel_indices)
66+
traces2 = self.segment2.get_traces(start_frame, end_frame, channel_indices)
67+
if self.operator == "add":
68+
return traces1 + traces2
69+
elif self.operator == "subtract":
70+
return traces1 - traces2
71+
else:
72+
raise ValueError(f"Unknown operator: {self.operator}")
73+
74+
75+
class AddRecordings(BaseOperatorRecording):
76+
def __init__(self, recording1, recording2):
77+
BaseOperatorRecording.__init__(self, recording1, recording2, operator="add")
78+
79+
80+
class SubtractRecordings(BaseOperatorRecording):
81+
def __init__(self, recording1, recording2):
82+
BaseOperatorRecording.__init__(self, recording1, recording2, operator="subtract")
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import numpy as np
3+
4+
import spikeinterface as si
5+
6+
7+
@pytest.fixture
8+
def recording():
9+
recording = si.generate_recording(durations=[10, 20], num_channels=4, sampling_frequency=20000)
10+
return recording
11+
12+
13+
def test_sum_recordings(recording):
14+
rec_sum = recording + recording
15+
for seg_index in range(rec_sum.get_num_segments()):
16+
traces_orig = recording.get_traces(segment_index=seg_index)
17+
traces_sum = rec_sum.get_traces(segment_index=seg_index)
18+
np.testing.assert_array_equal(traces_sum, traces_orig * 2)
19+
20+
21+
def test_subtract_recordings(recording):
22+
rec_sub = recording - recording
23+
for seg_index in range(rec_sub.get_num_segments()):
24+
traces_sub = rec_sub.get_traces(segment_index=seg_index)
25+
np.testing.assert_array_equal(traces_sub, np.zeros_like(traces_sub))
26+
27+
28+
def test_operator_combo(recording):
29+
rec_combo = recording - recording + recording - recording + recording
30+
for seg_index in range(rec_combo.get_num_segments()):
31+
traces_orig = recording.get_traces(segment_index=seg_index)
32+
traces_combo = rec_combo.get_traces(segment_index=seg_index)
33+
np.testing.assert_array_equal(traces_combo, traces_orig)
34+
35+
36+
def test_errors(recording):
37+
recording2 = si.generate_recording(durations=[10, 20], num_channels=4, sampling_frequency=10000)
38+
with pytest.raises(AssertionError):
39+
_ = recording + recording2
40+
with pytest.raises(AssertionError):
41+
_ = recording - recording2
42+
43+
recording_times = recording.clone()
44+
for segment_index in range(recording_times.get_num_segments()):
45+
recording_times.set_times(
46+
recording_times.get_times(segment_index=segment_index) + (segment_index + 1) * 5,
47+
segment_index=segment_index,
48+
)
49+
with pytest.raises(AssertionError):
50+
_ = recording + recording_times

0 commit comments

Comments
 (0)