Skip to content

Commit d95f3ea

Browse files
committed
Add context manager for busy indicator and use for traces, merges, compute
1 parent 0443dcb commit d95f3ea

File tree

4 files changed

+140
-143
lines changed

4 files changed

+140
-143
lines changed

spikeinterface_gui/mergeview.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _qt_make_layout(self):
198198
row_layout = QT.QHBoxLayout()
199199

200200
but = QT.QPushButton('Calculate merges')
201-
but.clicked.connect(self._qt_calculate_potential_automerge)
201+
but.clicked.connect(self._compute_merges)
202202
row_layout.addWidget(but)
203203

204204
if self.controller.curation:
@@ -269,8 +269,11 @@ def _qt_refresh(self):
269269
self.table.resizeColumnToContents(i)
270270
self.table.setSortingEnabled(True)
271271

272-
def _qt_calculate_potential_automerge(self):
273-
self.get_potential_merges()
272+
def _compute_merges(self):
273+
with self.busy_cursor():
274+
self.get_potential_merges()
275+
if len(self.proposed_merge_unit_groups) == 0:
276+
self.warning(f"No potential merges found with method {self.method}")
274277
self.refresh()
275278

276279
def _qt_on_spike_selection_changed(self):
@@ -318,7 +321,7 @@ def _panel_make_layout(self):
318321
self.table_area = pn.pane.Placeholder("No merges computed yet.", height=400)
319322

320323
self.caluculate_merges_button = pn.widgets.Button(name="Calculate merges", button_type="primary", sizing_mode="stretch_width")
321-
self.caluculate_merges_button.on_click(self._panel_calculate_merges)
324+
self.caluculate_merges_button.on_click(self._panel_compute_merges)
322325

323326
calculate_list = [self.caluculate_merges_button]
324327

@@ -378,11 +381,8 @@ def _panel_refresh(self):
378381
self.table.on_click(self._panel_on_click)
379382
self.table_area.update(self.table)
380383

381-
def _panel_calculate_merges(self, event):
382-
import panel as pn
383-
self.table_area.update(pn.indicators.LoadingSpinner(size=50, value=True))
384-
self.get_potential_merges()
385-
self.refresh()
384+
def _panel_compute_merges(self, event):
385+
self._compute_merges()
386386

387387
def _panel_on_method_change(self, event):
388388
self.method = event.new

spikeinterface_gui/tracemapview.py

Lines changed: 7 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
2-
3-
import matplotlib.cm
2+
from contextlib import nullcontext
43
import matplotlib.colors
54

65
from .view_base import ViewBase
@@ -29,9 +28,12 @@ def __init__(self, controller=None, parent=None, backend="qt"):
2928
self.channel_order_reverse = np.argsort(self.channel_order, kind="stable")
3029
self.color_limit = None
3130
self.last_data_curves = None
31+
self.factor = None
3232

3333
self.xsize = 0.5
3434
self._block_auto_refresh_and_notify = False
35+
self._retrieve_traces_time_checked = None
36+
self.trace_context = nullcontext
3537

3638
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
3739
MixinViewTrace.__init__(self)
@@ -62,60 +64,8 @@ def make_color_lut(self):
6264
if self.settings['reverse_colormap']:
6365
self.lut = self.lut[::-1]
6466

65-
66-
def get_data_in_chunk(self, t1, t2, segment_index):
67-
t_start = 0.0
68-
sr = self.controller.sampling_frequency
69-
70-
ind1 = max(0, int((t1 - t_start) * sr))
71-
ind2 = min(self.controller.get_num_samples(segment_index), int((t2 - t_start) * sr))
72-
73-
traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2)
74-
75-
sl = self.controller.segment_slices[segment_index]
76-
spikes_seg = self.controller.spikes[sl]
77-
i1, i2 = np.searchsorted(spikes_seg["sample_index"], [ind1, ind2])
78-
spikes_chunk = spikes_seg[i1:i2].copy()
79-
spikes_chunk["sample_index"] -= ind1
80-
81-
data_curves = traces_chunk[:, self.channel_order]
82-
83-
if data_curves.dtype != "float32":
84-
data_curves = data_curves.astype("float32")
85-
86-
times_chunk = np.arange(traces_chunk.shape[0], dtype='float64')/self.controller.sampling_frequency+max(t1, 0)
87-
88-
scatter_x = []
89-
scatter_y = []
90-
scatter_colors = []
91-
scatter_unit_ids = []
92-
93-
for unit_index, unit_id in self.controller.iter_visible_units():
94-
95-
inds = np.flatnonzero(spikes_chunk["unit_index"] == unit_index)
96-
if inds.size == 0:
97-
continue
98-
99-
# Get spikes for this unit
100-
unit_spikes = spikes_chunk[inds]
101-
channel_inds = unit_spikes["channel_index"]
102-
sample_inds = unit_spikes["sample_index"]
103-
104-
x = times_chunk[sample_inds]
105-
y = self.channel_order_reverse[channel_inds] + 0.5
106-
107-
# This should both for qt (QTColor) and panel (html color)
108-
color = self.get_unit_color(unit_id)
109-
110-
scatter_x.extend(x)
111-
scatter_y.extend(y)
112-
scatter_colors.extend([color] * len(x))
113-
scatter_unit_ids.extend([str(unit_id)] * len(x))
114-
115-
# used for auto scaled
116-
self.last_data_curves = data_curves
117-
118-
return times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids
67+
def get_visible_channel_inds(self):
68+
return np.arange(self.controller.analyzer.get_num_channels())
11969

12070

12171
## Qt ##
@@ -199,6 +149,7 @@ def _qt_seek(self, t):
199149
seg_index = self.controller.get_time()[1]
200150
times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids = \
201151
self.get_data_in_chunk(t1, t2, seg_index)
152+
self.last_data_curves = data_curves
202153

203154
if self.color_limit is None:
204155
self.color_limit = np.max(np.abs(data_curves))

spikeinterface_gui/traceview.py

Lines changed: 91 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import time
3+
from contextlib import nullcontext
34

45
from .view_base import ViewBase
56

@@ -13,6 +14,94 @@
1314
# *
1415

1516
class MixinViewTrace:
17+
18+
MAX_RETRIEVE_TIME_FOR_BUSY_CURSOR = 0.5 # seconds
19+
20+
def get_data_in_chunk(self, t1, t2, segment_index):
21+
with self.trace_context():
22+
if not self._retrieve_traces_time_checked:
23+
t_traces_start = time.perf_counter()
24+
t_start = 0.0
25+
sr = self.controller.sampling_frequency
26+
27+
ind1 = max(0, int((t1 - t_start) * sr))
28+
ind2 = min(self.controller.get_num_samples(segment_index), int((t2 - t_start) * sr))
29+
30+
traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2)
31+
32+
sl = self.controller.segment_slices[segment_index]
33+
spikes_seg = self.controller.spikes[sl]
34+
i1, i2 = np.searchsorted(spikes_seg["sample_index"], [ind1, ind2])
35+
spikes_chunk = spikes_seg[i1:i2].copy()
36+
spikes_chunk["sample_index"] -= ind1
37+
38+
visible_channel_inds = self.get_visible_channel_inds()
39+
40+
data_curves = traces_chunk[:, visible_channel_inds].T.copy()
41+
42+
if data_curves.dtype != "float32":
43+
data_curves = data_curves.astype("float32")
44+
45+
if self.factor is not None:
46+
n = visible_channel_inds.size
47+
gains = np.ones(n, dtype=float) * 1.0 / (self.factor * max(self.mad[visible_channel_inds]))
48+
offsets = np.arange(n)[::-1] - self.med[visible_channel_inds] * gains
49+
50+
data_curves *= gains[:, None]
51+
data_curves += offsets[:, None]
52+
53+
times_chunk = np.arange(traces_chunk.shape[0], dtype='float64') / self.controller.sampling_frequency+max(t1, 0)
54+
55+
scatter_x = []
56+
scatter_y = []
57+
scatter_colors = []
58+
scatter_unit_ids = []
59+
60+
global_to_local_chan_inds = np.zeros(self.controller.channel_ids.size, dtype='int64')
61+
global_to_local_chan_inds[visible_channel_inds] = np.arange(visible_channel_inds.size, dtype='int64')
62+
63+
for unit_index, unit_id in self.controller.iter_visible_units():
64+
65+
inds = np.flatnonzero(spikes_chunk["unit_index"] == unit_index)
66+
if inds.size == 0:
67+
continue
68+
69+
# Get spikes for this unit
70+
unit_spikes = spikes_chunk[inds]
71+
channel_inds = unit_spikes["channel_index"]
72+
sample_inds = unit_spikes["sample_index"]
73+
74+
chan_mask = np.isin(channel_inds, visible_channel_inds)
75+
if not np.any(chan_mask):
76+
continue
77+
78+
sample_inds = sample_inds[chan_mask]
79+
channel_inds = channel_inds[chan_mask]
80+
# Map channel indices to their positions in visible_channel_inds
81+
local_channel_inds = global_to_local_chan_inds[channel_inds]
82+
83+
84+
# Calculate y values using signal values
85+
x = times_chunk[sample_inds]
86+
y = data_curves[local_channel_inds, sample_inds]
87+
88+
# This should both for qt (QTColor) and panel (html color)
89+
color = self.get_unit_color(unit_id)
90+
91+
scatter_x.extend(x)
92+
scatter_y.extend(y)
93+
scatter_colors.extend([color] * len(x))
94+
scatter_unit_ids.extend([str(unit_id)] * len(x))
95+
if not self._retrieve_traces_time_checked:
96+
t_traces_end = time.perf_counter()
97+
elapsed = t_traces_end - t_traces_start
98+
if elapsed > self.MAX_RETRIEVE_TIME_FOR_BUSY_CURSOR:
99+
print(f"Trace retrieval took {elapsed:.3f} seconds. Enabling busy cursor.")
100+
self.trace_context = self.busy_cursor
101+
self._retrieve_traces_time_checked = True
102+
103+
return times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids
104+
16105
## Qt ##
17106
def _qt_create_toolbar(self):
18107
from .myqt import QT
@@ -284,6 +373,8 @@ def __init__(self, controller=None, parent=None, backend="qt"):
284373
self.factor = 15.0
285374
self.xsize = 0.5
286375
self._block_auto_refresh_and_notify = False
376+
self._retrieve_traces_time_checked = False
377+
self.trace_context = nullcontext
287378

288379
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
289380
MixinViewTrace.__init__(self)
@@ -308,82 +399,6 @@ def apply_gain_zoom(self, factor_ratio):
308399
self.factor *= factor_ratio
309400
self.refresh()
310401

311-
def get_data_in_chunk(self, t1, t2, segment_index):
312-
t_start = 0.0
313-
sr = self.controller.sampling_frequency
314-
315-
ind1 = max(0, int((t1 - t_start) * sr))
316-
ind2 = min(self.controller.get_num_samples(segment_index), int((t2 - t_start) * sr))
317-
318-
traces_chunk = self.controller.get_traces(segment_index=segment_index, start_frame=ind1, end_frame=ind2)
319-
320-
sl = self.controller.segment_slices[segment_index]
321-
spikes_seg = self.controller.spikes[sl]
322-
i1, i2 = np.searchsorted(spikes_seg["sample_index"], [ind1, ind2])
323-
spikes_chunk = spikes_seg[i1:i2].copy()
324-
spikes_chunk["sample_index"] -= ind1
325-
326-
visible_channel_inds = self.get_visible_channel_inds()
327-
328-
data_curves = traces_chunk[:, visible_channel_inds].T.copy()
329-
330-
if data_curves.dtype != "float32":
331-
data_curves = data_curves.astype("float32")
332-
333-
n = visible_channel_inds.size
334-
gains = np.ones(n, dtype=float) * 1.0 / (self.factor * max(self.mad[visible_channel_inds]))
335-
offsets = np.arange(n)[::-1] - self.med[visible_channel_inds] * gains
336-
337-
data_curves *= gains[:, None]
338-
data_curves += offsets[:, None]
339-
340-
times_chunk = np.arange(traces_chunk.shape[0], dtype='float64')/self.controller.sampling_frequency+max(t1, 0)
341-
342-
scatter_x = []
343-
scatter_y = []
344-
scatter_colors = []
345-
scatter_unit_ids = []
346-
347-
global_to_local_chan_inds = np.zeros(self.controller.channel_ids.size, dtype='int64')
348-
global_to_local_chan_inds[visible_channel_inds] = np.arange(visible_channel_inds.size, dtype='int64')
349-
350-
351-
for unit_index, unit_id in self.controller.iter_visible_units():
352-
353-
inds = np.flatnonzero(spikes_chunk["unit_index"] == unit_index)
354-
if inds.size == 0:
355-
continue
356-
357-
# Get spikes for this unit
358-
unit_spikes = spikes_chunk[inds]
359-
channel_inds = unit_spikes["channel_index"]
360-
sample_inds = unit_spikes["sample_index"]
361-
362-
363-
chan_mask = np.isin(channel_inds, visible_channel_inds)
364-
if not np.any(chan_mask):
365-
continue
366-
367-
sample_inds = sample_inds[chan_mask]
368-
channel_inds = channel_inds[chan_mask]
369-
# Map channel indices to their positions in visible_channel_inds
370-
local_channel_inds = global_to_local_chan_inds[channel_inds]
371-
372-
373-
# Calculate y values using signal values
374-
x = times_chunk[sample_inds]
375-
y = data_curves[local_channel_inds, sample_inds]
376-
377-
# This should both for qt (QTColor) and panel (html color)
378-
color = self.get_unit_color(unit_id)
379-
380-
scatter_x.extend(x)
381-
scatter_y.extend(y)
382-
scatter_colors.extend([color] * len(x))
383-
scatter_unit_ids.extend([str(unit_id)] * len(x))
384-
385-
return times_chunk, data_curves, scatter_x, scatter_y, scatter_colors, scatter_unit_ids
386-
387402
## qt ##
388403
def _qt_make_layout(self):
389404
from .myqt import QT

0 commit comments

Comments
 (0)