Skip to content

Commit 0647e98

Browse files
authored
Merge pull request #188 from alejoe91/busy-indicator
Add context manager for busy indicator
2 parents 38d3bc4 + 2a4d045 commit 0647e98

File tree

6 files changed

+240
-188
lines changed

6 files changed

+240
-188
lines changed

spikeinterface_gui/controller.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,14 +570,33 @@ def get_traces(self, trace_source='preprocessed', **kargs):
570570
cache_key = (kargs.get("segment_index", None), kargs.get("start_frame", None), kargs.get("end_frame", None))
571571
if cache_key in self._traces_cached:
572572
return self._traces_cached[cache_key]
573+
else:
574+
# check if start_frame and end_frame are a subset interval of a cached one
575+
for cached_key in self._traces_cached.keys():
576+
cached_seg = cached_key[0]
577+
cached_start = cached_key[1]
578+
cached_end = cached_key[2]
579+
req_seg = kargs.get("segment_index", None)
580+
req_start = kargs.get("start_frame", None)
581+
req_end = kargs.get("end_frame", None)
582+
if cached_seg is not None and req_seg is not None:
583+
if cached_seg != req_seg:
584+
continue
585+
if cached_start is not None and cached_end is not None and req_start is not None and req_end is not None:
586+
if req_start >= cached_start and req_end <= cached_end:
587+
# subset found
588+
traces = self._traces_cached[cached_key]
589+
start_offset = req_start - cached_start
590+
end_offset = req_end - cached_start
591+
return traces[start_offset:end_offset, :]
573592

574593
if len(self._traces_cached) > 4:
575594
self._traces_cached.pop(list(self._traces_cached.keys())[0])
576595

577596
if trace_source == 'preprocessed':
578597
rec = self.analyzer.recording
579598
elif trace_source == 'raw':
580-
raise NotImplemented
599+
raise NotImplementedError("Raw traces not implemented yet")
581600
# TODO get with parent recording the non process recording
582601
kargs['return_in_uV'] = self.return_in_uV
583602
traces = rec.get_traces(**kargs)

spikeinterface_gui/mergeview.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _qt_make_layout(self):
198198
row_layout = QT.QHBoxLayout()
199199

200200
but = QT.QPushButton('Calculate merges')
201-
but.clicked.connect(self._qt_calculate_potential_automerge)
201+
but.clicked.connect(self._compute_merges)
202202
row_layout.addWidget(but)
203203

204204
if self.controller.curation:
@@ -269,8 +269,11 @@ def _qt_refresh(self):
269269
self.table.resizeColumnToContents(i)
270270
self.table.setSortingEnabled(True)
271271

272-
def _qt_calculate_potential_automerge(self):
273-
self.get_potential_merges()
272+
def _compute_merges(self):
273+
with self.busy_cursor():
274+
self.get_potential_merges()
275+
if len(self.proposed_merge_unit_groups) == 0:
276+
self.warning(f"No potential merges found with method {self.method}")
274277
self.refresh()
275278

276279
def _qt_on_spike_selection_changed(self):
@@ -318,7 +321,7 @@ def _panel_make_layout(self):
318321
self.table_area = pn.pane.Placeholder("No merges computed yet.", height=400)
319322

320323
self.caluculate_merges_button = pn.widgets.Button(name="Calculate merges", button_type="primary", sizing_mode="stretch_width")
321-
self.caluculate_merges_button.on_click(self._panel_calculate_merges)
324+
self.caluculate_merges_button.on_click(self._panel_compute_merges)
322325

323326
calculate_list = [self.caluculate_merges_button]
324327

@@ -378,11 +381,8 @@ def _panel_refresh(self):
378381
self.table.on_click(self._panel_on_click)
379382
self.table_area.update(self.table)
380383

381-
def _panel_calculate_merges(self, event):
382-
import panel as pn
383-
self.table_area.update(pn.indicators.LoadingSpinner(size=50, value=True))
384-
self.get_potential_merges()
385-
self.refresh()
384+
def _panel_compute_merges(self, event):
385+
self._compute_merges()
386386

387387
def _panel_on_method_change(self, event):
388388
self.method = event.new

spikeinterface_gui/tracemapview.py

Lines changed: 16 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
2-
3-
import matplotlib.cm
2+
from contextlib import nullcontext
43
import matplotlib.colors
54

65
from .view_base import ViewBase
@@ -29,16 +28,17 @@ def __init__(self, controller=None, parent=None, backend="qt"):
2928
self.channel_order_reverse = np.argsort(self.channel_order, kind="stable")
3029
self.color_limit = None
3130
self.last_data_curves = None
31+
self.factor = None
3232

3333
self.xsize = 0.5
3434
self._block_auto_refresh_and_notify = False
35+
self.trace_context = nullcontext
3536

3637
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
3738
MixinViewTrace.__init__(self)
3839

3940
self.make_color_lut()
4041

41-
4242
def apply_gain_zoom(self, factor_ratio):
4343
if self.color_limit is None:
4444
return
@@ -62,61 +62,8 @@ def make_color_lut(self):
6262
if self.settings['reverse_colormap']:
6363
self.lut = self.lut[::-1]
6464

65-
66-
def get_data_in_chunk(self, t1, t2, segment_index):
67-
t_start = 0.0
68-
sr = self.controller.sampling_frequency
69-
70-
ind1 = max(0, int((t1 - t_start) * sr))
71-
ind2 = min(self.controller.get_num_samples(segment_index), int((t2 - t_start) * sr))
72-
73-
traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2)
74-
75-
sl = self.controller.segment_slices[segment_index]
76-
spikes_seg = self.controller.spikes[sl]
77-
i1, i2 = np.searchsorted(spikes_seg["sample_index"], [ind1, ind2])
78-
spikes_chunk = spikes_seg[i1:i2].copy()
79-
spikes_chunk["sample_index"] -= ind1
80-
81-
data_curves = traces_chunk[:, self.channel_order]
82-
83-
if data_curves.dtype != "float32":
84-
data_curves = data_curves.astype("float32")
85-
86-
times_chunk = np.arange(traces_chunk.shape[0], dtype='float64')/self.controller.sampling_frequency+max(t1, 0)
87-
88-
scatter_x = []
89-
scatter_y = []
90-
scatter_colors = []
91-
scatter_unit_ids = []
92-
93-
for unit_index, unit_id in self.controller.iter_visible_units():
94-
95-
inds = np.flatnonzero(spikes_chunk["unit_index"] == unit_index)
96-
if inds.size == 0:
97-
continue
98-
99-
# Get spikes for this unit
100-
unit_spikes = spikes_chunk[inds]
101-
channel_inds = unit_spikes["channel_index"]
102-
sample_inds = unit_spikes["sample_index"]
103-
104-
x = times_chunk[sample_inds]
105-
y = self.channel_order_reverse[channel_inds] + 0.5
106-
107-
# This should both for qt (QTColor) and panel (html color)
108-
color = self.get_unit_color(unit_id)
109-
110-
scatter_x.extend(x)
111-
scatter_y.extend(y)
112-
scatter_colors.extend([color] * len(x))
113-
scatter_unit_ids.extend([str(unit_id)] * len(x))
114-
115-
# used for auto scaled
116-
self.last_data_curves = data_curves
117-
118-
return times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids
119-
65+
def get_visible_channel_inds(self):
66+
return self.channel_order
12067

12168
## Qt ##
12269
def _qt_make_layout(self, **kargs):
@@ -149,7 +96,6 @@ def _qt_make_layout(self, **kargs):
14996
# self.on_params_changed(do_refresh=False)
15097
#this do refresh
15198
self._qt_change_segment(0)
152-
15399

154100
def _qt_on_settings_changed(self, do_refresh=True):
155101

@@ -197,9 +143,10 @@ def _qt_seek(self, t):
197143
self.scroll_time.valueChanged.connect(self._qt_on_scroll_time)
198144

199145
seg_index = self.controller.get_time()[1]
200-
times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \
146+
times_chunk, data_curves, scatter_x, scatter_y, scatter_colors = \
201147
self.get_data_in_chunk(t1, t2, seg_index)
202-
148+
data_curves = data_curves.T
149+
203150
if self.color_limit is None:
204151
self.color_limit = np.max(np.abs(data_curves))
205152

@@ -212,8 +159,8 @@ def _qt_seek(self, t):
212159
# self.scatter.clear()
213160
self.scatter.setData(x=scatter_x, y=scatter_y, brush=scatter_colors)
214161

215-
self.plot.setXRange( t1, t2, padding = 0.0)
216-
self.plot.setYRange(0, num_chans, padding = 0.0)
162+
self.plot.setXRange(t1, t2, padding=0.0)
163+
self.plot.setYRange(0, num_chans, padding=0.0)
217164

218165
def _qt_on_time_info_updated(self):
219166
# Update segment and time slider range
@@ -275,7 +222,7 @@ def _panel_make_layout(self):
275222
# Add data sources
276223
self.image_source = ColumnDataSource({"image": [], "x": [], "y": [], "dw": [], "dh": []})
277224

278-
self.spike_source = ColumnDataSource({"x": [], "y": [], "color": [], "unit_id": []})
225+
self.spike_source = ColumnDataSource({"x": [], "y": [], "color": []})
279226

280227
# Create color mapper
281228
self.color_mapper = LinearColorMapper(palette="Greys256", low=-1, high=1)
@@ -317,8 +264,9 @@ def _panel_refresh(self):
317264
else:
318265
auto_scale = False
319266

320-
times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \
267+
times_chunk, data_curves, scatter_x, scatter_y, scatter_colors = \
321268
self.get_data_in_chunk(t1, t2, seg_index)
269+
data_curves = data_curves.T
322270

323271
if self.color_limit is None:
324272
self.color_limit = np.max(np.abs(data_curves))
@@ -335,7 +283,6 @@ def _panel_refresh(self):
335283
"x": scatter_x,
336284
"y": scatter_y,
337285
"color": scatter_colors,
338-
"unit_id": scatter_unit_ids,
339286
})
340287

341288
if auto_scale:
@@ -350,7 +297,7 @@ def _panel_refresh(self):
350297
# TODO: if from a different unit, change unit visibility
351298
def _panel_on_tap(self, event):
352299
seg_index = self.controller.get_time()[1]
353-
ind_spike_nearest = self.find_nearest_spike(self.controller, event.x, seg_index)
300+
ind_spike_nearest = find_nearest_spike(self.controller, event.x, seg_index)
354301
if ind_spike_nearest is not None:
355302
self.controller.set_indices_spike_selected([ind_spike_nearest])
356303
self._panel_seek_with_selected_spike()
@@ -364,8 +311,8 @@ def _panel_on_spike_selection_changed(self):
364311
self._panel_seek_with_selected_spike()
365312

366313
def _panel_gain_zoom(self, event):
367-
factor = 1.3 if event.delta > 0 else 1 / 1.3
368-
self.color_mapper.high = self.color_mapper.high * factor
314+
factor_ratio = 1.3 if event.delta > 0 else 1 / 1.3
315+
self.color_mapper.high = self.color_mapper.high * factor_ratio
369316
self.color_mapper.low = -self.color_mapper.high
370317

371318
def _panel_auto_scale(self, event):

0 commit comments

Comments
 (0)