|
| 1 | +from spikeinterface.core import BaseRecording, BaseRecordingSegment |
| 2 | +from spikeinterface.core.recording_tools import get_rec_attributes, do_recording_attributes_match |
| 3 | + |
| 4 | + |
| 5 | +class BaseOperatorRecording(BaseRecording): |
| 6 | + """Base class for operator recordings.""" |
| 7 | + |
| 8 | + def __init__(self, recording1, recording2, operator: str): |
| 9 | + assert operator in ["add", "subtract"], "Operator must be 'add' or 'subtract'" |
| 10 | + assert all( |
| 11 | + isinstance(rec, BaseRecording) for rec in [recording1, recording2] |
| 12 | + ), "'recordings' must be a list of RecordingExtractor" |
| 13 | + |
| 14 | + rec_attrs2 = get_rec_attributes(recording2) |
| 15 | + assert do_recording_attributes_match( |
| 16 | + recording1, rec_attrs2 |
| 17 | + ), "Both recordings must have the same sampling frequency and channel ids" |
| 18 | + assert self.are_times_kwargs_compatible( |
| 19 | + recording1, recording2 |
| 20 | + ), "Both recordings must have the same time parameters" |
| 21 | + |
| 22 | + channel_ids = recording1.channel_ids |
| 23 | + sampling_frequency = recording1.sampling_frequency |
| 24 | + dtype = recording1.get_dtype() |
| 25 | + |
| 26 | + BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) |
| 27 | + |
| 28 | + for segment1, segment2 in zip(recording1._recording_segments, recording2._recording_segments): |
| 29 | + add_segment = OperatorRecordingSegment(segment1, segment2, operator) |
| 30 | + self.add_recording_segment(add_segment) |
| 31 | + |
| 32 | + self._kwargs = dict(recording1=recording1, recording2=recording2, operator=operator) |
| 33 | + |
| 34 | + def are_times_kwargs_compatible(self, recording1, recording2) -> bool: |
| 35 | + import numpy as np |
| 36 | + |
| 37 | + for segment_index in range(recording1.get_num_segments()): |
| 38 | + time_kwargs1 = recording1._recording_segments[segment_index].get_times_kwargs() |
| 39 | + time_kwargs2 = recording2._recording_segments[segment_index].get_times_kwargs() |
| 40 | + for key in time_kwargs1.keys(): |
| 41 | + val1 = time_kwargs1[key] |
| 42 | + val2 = time_kwargs2[key] |
| 43 | + if (val1 is None and val2 is not None) or (val1 is not None and val2 is None): |
| 44 | + return False |
| 45 | + if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray): |
| 46 | + if not np.array_equal(val1, val2): |
| 47 | + return False |
| 48 | + else: |
| 49 | + if val1 != val2: |
| 50 | + return False |
| 51 | + return True |
| 52 | + |
| 53 | + |
| 54 | +class OperatorRecordingSegment(BaseRecordingSegment): |
| 55 | + def __init__(self, segment1, segment2, operator: str): |
| 56 | + BaseRecordingSegment.__init__(self, **segment1.get_times_kwargs()) |
| 57 | + self.segment1 = segment1 |
| 58 | + self.segment2 = segment2 |
| 59 | + self.operator = operator |
| 60 | + |
| 61 | + def get_num_samples(self): |
| 62 | + return self.segment1.get_num_samples() |
| 63 | + |
| 64 | + def get_traces(self, start_frame, end_frame, channel_indices): |
| 65 | + traces1 = self.segment1.get_traces(start_frame, end_frame, channel_indices) |
| 66 | + traces2 = self.segment2.get_traces(start_frame, end_frame, channel_indices) |
| 67 | + if self.operator == "add": |
| 68 | + return traces1 + traces2 |
| 69 | + elif self.operator == "subtract": |
| 70 | + return traces1 - traces2 |
| 71 | + else: |
| 72 | + raise ValueError(f"Unknown operator: {self.operator}") |
| 73 | + |
| 74 | + |
| 75 | +class AddRecordings(BaseOperatorRecording): |
| 76 | + def __init__(self, recording1, recording2): |
| 77 | + BaseOperatorRecording.__init__(self, recording1, recording2, operator="add") |
| 78 | + |
| 79 | + |
| 80 | +class SubtractRecordings(BaseOperatorRecording): |
| 81 | + def __init__(self, recording1, recording2): |
| 82 | + BaseOperatorRecording.__init__(self, recording1, recording2, operator="subtract") |
0 commit comments