Skip to content

Commit 68392ee

Browse files
committed
Revert gain/auto-scale changes and add check on smaller segmenr in cache
1 parent e3628b7 commit 68392ee

File tree

3 files changed

+78
-103
lines changed

3 files changed

+78
-103
lines changed

spikeinterface_gui/controller.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,25 @@ def get_traces(self, trace_source='preprocessed', **kargs):
570570
cache_key = (kargs.get("segment_index", None), kargs.get("start_frame", None), kargs.get("end_frame", None))
571571
if cache_key in self._traces_cached:
572572
return self._traces_cached[cache_key]
573+
else:
574+
# check if start_frame and end_frame are a subset interval of a cached one
575+
for cached_key in self._traces_cached.keys():
576+
cached_seg = cached_key[0]
577+
cached_start = cached_key[1]
578+
cached_end = cached_key[2]
579+
req_seg = kargs.get("segment_index", None)
580+
req_start = kargs.get("start_frame", None)
581+
req_end = kargs.get("end_frame", None)
582+
if cached_seg is not None and req_seg is not None:
583+
if cached_seg != req_seg:
584+
continue
585+
if cached_start is not None and cached_end is not None and req_start is not None and req_end is not None:
586+
if req_start >= cached_start and req_end <= cached_end:
587+
# subset found
588+
traces = self._traces_cached[cached_key]
589+
start_offset = req_start - cached_start
590+
end_offset = req_end - cached_start
591+
return traces[start_offset:end_offset, :]
573592

574593
if len(self._traces_cached) > 4:
575594
self._traces_cached.pop(list(self._traces_cached.keys())[0])

spikeinterface_gui/tracemapview.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ def make_color_lut(self):
5555
def get_visible_channel_inds(self):
5656
return self.channel_order
5757

58+
def apply_gain_zoom(self, factor_ratio):
59+
if self.color_limit is None:
60+
return
61+
self.color_limit = self.color_limit * factor_ratio
62+
self.refresh()
63+
64+
def auto_scale(self):
65+
if self.last_data_curves is not None:
66+
self.color_limit = np.max(np.abs(self.last_data_curves))
67+
self.refresh()
68+
5869
## Qt ##
5970
def _qt_make_layout(self, **kargs):
6071
from .myqt import QT
@@ -87,17 +98,6 @@ def _qt_make_layout(self, **kargs):
8798
#this do refresh
8899
self._qt_change_segment(0)
89100

90-
def _qt_gain_zoom(self, factor_ratio):
91-
if self.color_limit is None:
92-
return
93-
self.color_limit = self.color_limit * factor_ratio
94-
self.image.setLevels([-self.color_limit, self.color_limit])
95-
96-
def _qt_auto_scale(self):
97-
if self.last_data_curves is not None:
98-
self.color_limit = np.max(np.abs(self.last_data_curves))
99-
self._qt_gain_zoom(1.0)
100-
101101
def _qt_on_settings_changed(self, do_refresh=True):
102102

103103
self.spinbox_xsize.opts['bounds'] = [0.001, self.settings['xsize_max']]
@@ -312,15 +312,15 @@ def _panel_on_spike_selection_changed(self):
312312
self._panel_seek_with_selected_spike()
313313

314314
def _panel_gain_zoom(self, event):
315-
if event is None:
316-
factor_ratio = 1.0
317-
else:
318-
factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3
315+
factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3
319316
self.color_mapper.high = self.color_mapper.high * factor_ratio
320317
self.color_mapper.low = -self.color_mapper.high
321318

322319
def _panel_auto_scale(self, event):
323-
self._panel_gain_zoom(None)
320+
if self.last_data_curves is not None:
321+
self.color_limit = np.max(np.abs(self.last_data_curves))
322+
self.color_mapper.high = self.color_limit
323+
self.color_mapper.low = -self.color_limit
324324

325325
def _panel_on_time_info_updated(self):
326326
# Update segment and time slider range

spikeinterface_gui/traceview.py

Lines changed: 43 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ def get_data_in_chunk(self, t1, t2, segment_index):
5959
scatter_x = []
6060
scatter_y = []
6161
scatter_colors = []
62-
sample_indices = []
63-
local_channel_indices = []
6462

6563
global_to_local_chan_inds = np.zeros(self.controller.channel_ids.size, dtype='int64')
6664
global_to_local_chan_inds[visible_channel_inds] = np.arange(visible_channel_inds.size, dtype='int64')
@@ -99,15 +97,6 @@ def get_data_in_chunk(self, t1, t2, segment_index):
9997
scatter_x.extend(x)
10098
scatter_y.extend(y)
10199
scatter_colors.extend([color] * len(x))
102-
sample_indices.extend(sample_inds.tolist())
103-
local_channel_indices.extend(local_channel_inds.tolist())
104-
105-
self.last_scatter = dict(
106-
times=scatter_x,
107-
sample_indices=sample_indices,
108-
local_channel_indices=local_channel_indices,
109-
colors=scatter_colors,
110-
)
111100

112101
if not self._retrieve_traces_time_checked:
113102
t_traces_end = time.perf_counter()
@@ -150,7 +139,7 @@ def _qt_create_toolbar(self):
150139
self.spinbox_xsize.sigValueChanged.connect(self.refresh)
151140

152141
but = QT.QPushButton('auto scale')
153-
but.clicked.connect(self._qt_auto_scale)
142+
but.clicked.connect(self.auto_scale)
154143

155144
tb.addWidget(but)
156145

@@ -168,7 +157,7 @@ def _qt_initialize_plot(self):
168157

169158
self.viewBox.doubleclicked.connect(self._qt_scatter_item_clicked)
170159

171-
self.viewBox.gain_zoom.connect(self._qt_gain_zoom)
160+
self.viewBox.gain_zoom.connect(self.apply_gain_zoom)
172161
self.viewBox.xsize_zoom.connect(self._qt_xsize_zoom)
173162

174163
self.signals_curve = pg.PlotCurveItem(pen='#7FFF00', connect='finite')
@@ -189,11 +178,6 @@ def _qt_initialize_plot(self):
189178
self.gains = None
190179
self.offsets = None
191180

192-
def _qt_auto_scale(self):
193-
# to be defined in the child class
194-
pass
195-
196-
197181
def _qt_update_scroll_limits(self):
198182
seg_index = self.controller.get_time()[1]
199183
length = self.controller.get_num_samples(seg_index)
@@ -228,6 +212,8 @@ def _qt_on_combo_seg_changed(self):
228212

229213
def _qt_on_xsize_changed(self):
230214
xsize = self.spinbox_xsize.value()
215+
# Reset trace retrieval check: might require more or less time now!
216+
# self._retrieve_traces_time_checked = False
231217
self.xsize = xsize
232218
if not self._block_auto_refresh_and_notify:
233219
self.refresh()
@@ -308,10 +294,6 @@ def _panel_create_toolbar(self):
308294
value_throttled=0, sizing_mode="stretch_width")
309295
self.time_slider.param.watch(self._panel_on_time_slider_changed, "value_throttled")
310296

311-
def _panel_auto_scale(self, event):
312-
# to be defined in the child class
313-
pass
314-
315297
def _panel_on_segment_changed(self, event):
316298
seg_index = int(event.new.split()[-1])
317299
self._panel_change_segment(seg_index)
@@ -331,6 +313,8 @@ def _panel_change_segment(self, seg_index):
331313

332314
def _panel_on_xsize_changed(self, event):
333315
self.xsize = event.new
316+
# Reset trace retrieval check: might require more or less time now!
317+
# self._retrieve_traces_time_checked = False
334318
if not self._block_auto_refresh_and_notify:
335319
self.refresh()
336320
self.notify_time_info_updated()
@@ -373,6 +357,25 @@ def _panel_seek_with_selected_spike(self):
373357
self.refresh()
374358
self.notify_time_info_updated()
375359

360+
# TODO: pan behavior like Qt?
361+
# def _panel_on_pan_start(self, event):
362+
# self.drag_state["x_start"] = event.x
363+
364+
# def _panel_on_pan(self, event):
365+
# print("Panning...")
366+
# if self.drag_state["x_start"] is None:
367+
# return
368+
# delta = event.x - self.drag_state["x_start"]
369+
# print(f"Delta: {delta}")
370+
# factor = 1.0 - (delta * 10) # adjust sensitivity
371+
# factor = max(0.5, min(factor, 2.0)) # limit zoom factor
372+
# print(f"Change xsize by factor: {factor}. From {self.xsize} to {self.xsize * factor}")
373+
374+
# self.xsize_spinner.value = self.xsize * factor
375+
376+
# def _panel_on_pan_end(self, event):
377+
# self.drag_state["x_start"] = None
378+
376379

377380
class TraceView(ViewBase, MixinViewTrace):
378381
_supported_backend = ['qt', 'panel']
@@ -417,39 +420,21 @@ def get_visible_channel_inds(self):
417420
inds = inds[:n_max]
418421
return inds
419422

420-
## qt ##
421-
def _qt_gain_zoom(self, factor_ratio):
422-
self.factor *= factor_ratio
423-
visible_channel_inds = self.get_visible_channel_inds()
424-
if len(visible_channel_inds) == 0:
425-
return
426-
data_curves = self.last_data_curves.copy()
427-
428-
if self.factor is not None:
429-
n = visible_channel_inds.size
430-
gains = np.ones(n, dtype=float) * 1.0 / (self.factor * max(self.mad[visible_channel_inds]))
431-
offsets = np.arange(n)[::-1] - self.med[visible_channel_inds] * gains
432-
433-
data_curves *= gains[:, None]
434-
data_curves += offsets[:, None]
435-
connect = np.ones(data_curves.shape, dtype='bool')
436-
connect[:, -1] = 0
437-
times_chunk_tile = np.tile(self.last_times_chunk, visible_channel_inds.size)
438-
self.signals_curve.setData(times_chunk_tile, data_curves.flatten(), connect=connect.flatten())
439-
440-
if self.last_scatter is not None:
441-
scatter_x = self.last_scatter['times']
442-
if len(scatter_x) > 0:
443-
local_channel_inds = self.last_scatter['local_channel_indices']
444-
sample_inds = self.last_scatter['sample_indices']
445-
scatter_y = data_curves[local_channel_inds, sample_inds]
446-
scatter_colors = self.last_scatter['colors']
447-
self.scatter.setData(x=scatter_x, y=scatter_y, brush=scatter_colors)
448-
449-
def _qt_auto_scale(self):
423+
def auto_scale(self):
450424
self.factor = 15.
451-
self._qt_gain_zoom(1.0)
425+
trace_context = self.trace_context
426+
self.trace_context = nullcontext
427+
self.refresh()
428+
self.trace_context = trace_context
429+
430+
def apply_gain_zoom(self, factor_ratio):
431+
self.factor *= factor_ratio
432+
trace_context = self.trace_context
433+
self.trace_context = nullcontext
434+
self.refresh()
435+
self.trace_context = trace_context
452436

437+
## qt ##
453438
def _qt_make_layout(self):
454439
from .myqt import QT
455440
import pyqtgraph as pg
@@ -583,12 +568,12 @@ def _panel_make_layout(self):
583568
import panel as pn
584569
import bokeh.plotting as bpl
585570
from .utils_panel import _bg_color
586-
from bokeh.models import ColumnDataSource, Range1d, HoverTool
571+
from bokeh.models import ColumnDataSource, Range1d
587572
from bokeh.events import Tap, MouseWheel
588573

589574
self.figure = bpl.figure(
590575
sizing_mode="stretch_both",
591-
tools="box_zoom,reset",
576+
tools="reset",
592577
background_fill_color=_bg_color,
593578
border_fill_color=_bg_color,
594579
outline_line_color="white",
@@ -699,40 +684,11 @@ def _panel_gain_zoom(self, event):
699684
factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3
700685
else:
701686
factor_ratio = 1.0
702-
self.factor *= factor_ratio
703-
visible_channel_inds = self.get_visible_channel_inds()
704-
data_curves = self.last_data_curves.copy()
705-
if len(visible_channel_inds) == 0:
706-
return
707-
if self.factor is not None:
708-
n = visible_channel_inds.size
709-
gains = np.ones(n, dtype=float) * 1.0 / (self.factor * max(self.mad[visible_channel_inds]))
710-
offsets = np.arange(n)[::-1] - self.med[visible_channel_inds] * gains
711-
712-
data_curves *= gains[:, None]
713-
data_curves += offsets[:, None]
714-
715-
self.signal_source.data.update(
716-
{
717-
"ys": [data_curves[i, :] for i in range(n)],
718-
}
719-
)
720-
721-
if self.last_scatter is not None:
722-
scatter_x = self.last_scatter['times']
723-
if len(scatter_x) > 0:
724-
local_channel_inds = self.last_scatter['local_channel_indices']
725-
sample_inds = self.last_scatter['sample_indices']
726-
scatter_y = data_curves[local_channel_inds, sample_inds]
727-
self.spike_source.data.update(
728-
{
729-
"y": scatter_y,
730-
}
731-
)
687+
factor = 1.3 if event.delta > 0 else 1 / 1.3
688+
self.apply_gain_zoom(factor)
732689

733690
def _panel_auto_scale(self, event):
734-
self.factor = 15.0
735-
self._panel_gain_zoom(None)
691+
self.auto_scale()
736692

737693
def _panel_on_time_info_updated(self):
738694
# Update segment and time slider range

0 commit comments

Comments
 (0)