From 53ee86a26e375b601f84df136f2546aab97f389e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 12:42:38 +0200 Subject: [PATCH 1/4] Correct handling of time in plot_traces --- src/spikeinterface/widgets/traces.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 86f2350a85..56808d44af 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -125,8 +125,10 @@ def __init__( if not rec0.has_time_vector(segment_index=segment_index): times = None - t_start = 0 - t_end = rec0.get_duration(segment_index=segment_index) + t_start = rec0.sample_index_to_time(0, segment_index=segment_index) + t_end = rec0.sample_index_to_time( + rec0.get_num_samples(segment_index=segment_index), segment_index=segment_index + ) else: times = rec0.get_times(segment_index=segment_index) t_start = times[0] @@ -673,13 +675,17 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s assert all( rec.has_scaleable_traces() for rec in recordings.values() ), "Some recording layers do not have scaled traces. Use `return_scaled=False`" + frame_range = np.array( + [ + rec0.time_to_sample_index(time_range[0], segment_index=segment_index), + rec0.time_to_sample_index(time_range[1], segment_index=segment_index), + ] + ) if times is not None: - frame_range = np.searchsorted(times, time_range) times = times[frame_range[0] : frame_range[1]] else: - frame_range = (time_range * fs).astype("int64", copy=False) - a_max = rec0.get_num_frames(segment_index=segment_index) - frame_range = np.clip(frame_range, 0, a_max) + num_samples = rec0.get_num_samples(segment_index=segment_index) + frame_range = np.clip(frame_range, 0, num_samples) time_range = frame_range / fs times = np.arange(frame_range[0], frame_range[1]) / fs From e225fd5c79afabb3b8907467fe10bbf8f8d7b079 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 12:49:33 +0200 Subject: [PATCH 2/4] Add more checks and error messages --- src/spikeinterface/core/frameslicerecording.py | 4 +++- src/spikeinterface/widgets/traces.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index f2ef38e691..219aec301b 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -34,7 +34,9 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None): if start_frame is None: start_frame = 0 else: - assert 0 <= start_frame < parent_size + assert ( + 0 <= start_frame < parent_size + ), f"'start_frame' must be fewer than number of samples in parent: {parent_size}" if end_frame is None: end_frame = parent_size diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 56808d44af..482139ca75 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -151,6 +151,11 @@ def __init__( ) time_range[1] = t_end + if time_range[0] < t_start or time_range[1] < t_start: + raise ValueError(f"All time_range values must be greater than {t_start}") + if time_range[1] <= time_range[0]: + raise ValueError("time_range[1] must be greater than time_range[0]") + assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"' if mode == "auto": if len(channel_ids) <= 64: From c9a0e642604610407836a243281c78d15f896b6d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 24 Oct 2024 11:36:27 +0200 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/frameslicerecording.py | 2 +- src/spikeinterface/widgets/traces.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 219aec301b..fdedf37266 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -36,7 +36,7 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None): else: assert ( 0 <= start_frame < parent_size - ), f"'start_frame' must be fewer than number of samples in parent: {parent_size}" + ), f"`start_frame` must be fewer than number of samples in parent: {parent_size}" if end_frame is None: end_frame = parent_size diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 482139ca75..4ccb0671a2 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -152,9 +152,9 @@ def __init__( time_range[1] = t_end if time_range[0] < t_start or time_range[1] < t_start: - raise ValueError(f"All time_range values must be greater than {t_start}") + raise ValueError(f"All `time_range` values must be greater than {t_start}") if time_range[1] <= time_range[0]: - raise ValueError("time_range[1] must be greater than time_range[0]") + raise ValueError("`time_range[1]` must be greater than `time_range[0]`") assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"' if mode == "auto": From 0b6d49f7ebdd265ff9f24d4da368034d93844bf6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 24 Oct 2024 12:29:00 +0200 Subject: [PATCH 4/4] frame_range to sample_range --- src/spikeinterface/widgets/traces.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 4ccb0671a2..da649fd76a 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -165,7 +165,7 @@ def __init__( mode = mode cmap = cmap - times_in_range, list_traces, frame_range, channel_ids = _get_trace_list( + times_in_range, list_traces, sample_range, channel_ids = _get_trace_list( recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled, times=times ) @@ -257,7 +257,7 @@ def __init__( channel_ids=channel_ids, channel_locations=channel_locations, time_range=time_range, - frame_range=frame_range, + sample_range=sample_range, times_in_range=times_in_range, layer_keys=layer_keys, list_traces=list_traces, @@ -543,7 +543,7 @@ def _retrieve_traces(self, change=None): time_range = np.array([times[start_frame], times[end_frame]]) self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} - times_in_range, list_traces, frame_range, channel_ids = _get_trace_list( + times_in_range, list_traces, sample_range, channel_ids = _get_trace_list( self._selected_recordings, channel_ids, time_range, @@ -556,7 +556,7 @@ def _retrieve_traces(self, change=None): self._list_traces = list_traces self._times_in_range = times_in_range self._time_range = time_range - self._frame_range = (start_frame, end_frame) + self._sample_range = (start_frame, end_frame) self._segment_index = segment_index self._update_plot() @@ -569,7 +569,7 @@ def _update_plot(self, change=None): layer_keys = self._get_layers() data_plot["mode"] = mode - data_plot["frame_range"] = self._frame_range + data_plot["sample_range"] = self._sample_range data_plot["time_range"] = self._time_range if self.colorbar.value: data_plot["with_colorbar"] = True @@ -680,30 +680,29 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s assert all( rec.has_scaleable_traces() for rec in recordings.values() ), "Some recording layers do not have scaled traces. Use `return_scaled=False`" - frame_range = np.array( + sample_range = np.array( [ rec0.time_to_sample_index(time_range[0], segment_index=segment_index), rec0.time_to_sample_index(time_range[1], segment_index=segment_index), ] ) if times is not None: - times = times[frame_range[0] : frame_range[1]] + times = times[sample_range[0] : sample_range[1]] else: num_samples = rec0.get_num_samples(segment_index=segment_index) - frame_range = np.clip(frame_range, 0, num_samples) - time_range = frame_range / fs - times = np.arange(frame_range[0], frame_range[1]) / fs + sample_range = np.clip(sample_range, 0, num_samples) + times = np.arange(sample_range[0], sample_range[1]) / fs list_traces = [] - for rec_name, rec in recordings.items(): + for _, rec in recordings.items(): traces = rec.get_traces( segment_index=segment_index, channel_ids=channel_ids, - start_frame=frame_range[0], - end_frame=frame_range[1], + start_frame=sample_range[0], + end_frame=sample_range[1], return_scaled=return_scaled, ) list_traces.append(traces) - return times, list_traces, frame_range, channel_ids + return times, list_traces, sample_range, channel_ids