11import numpy as np
22import time
3+ from contextlib import nullcontext
34
45from .view_base import ViewBase
56
1314# *
1415
1516class MixinViewTrace :
17+
18+ MAX_RETRIEVE_TIME_FOR_BUSY_CURSOR = 0.5 # seconds
19+
20+ def get_data_in_chunk (self , t1 , t2 , segment_index ):
21+ with self .trace_context ():
22+ if not self ._retrieve_traces_time_checked :
23+ t_traces_start = time .perf_counter ()
24+ t_start = 0.0
25+ sr = self .controller .sampling_frequency
26+
27+ ind1 = max (0 , int ((t1 - t_start ) * sr ))
28+ ind2 = min (self .controller .get_num_samples (segment_index ), int ((t2 - t_start ) * sr ))
29+
30+ traces_chunk = self .controller .get_traces (segment_index = segment_index , start_frame = ind1 , end_frame = ind2 )
31+
32+ sl = self .controller .segment_slices [segment_index ]
33+ spikes_seg = self .controller .spikes [sl ]
34+ i1 , i2 = np .searchsorted (spikes_seg ["sample_index" ], [ind1 , ind2 ])
35+ spikes_chunk = spikes_seg [i1 :i2 ].copy ()
36+ spikes_chunk ["sample_index" ] -= ind1
37+
38+ visible_channel_inds = self .get_visible_channel_inds ()
39+
40+ data_curves = traces_chunk [:, visible_channel_inds ].T .copy ()
41+
42+ if data_curves .dtype != "float32" :
43+ data_curves = data_curves .astype ("float32" )
44+
45+ if self .factor is not None :
46+ n = visible_channel_inds .size
47+ gains = np .ones (n , dtype = float ) * 1.0 / (self .factor * max (self .mad [visible_channel_inds ]))
48+ offsets = np .arange (n )[::- 1 ] - self .med [visible_channel_inds ] * gains
49+
50+ data_curves *= gains [:, None ]
51+ data_curves += offsets [:, None ]
52+
53+ times_chunk = np .arange (traces_chunk .shape [0 ], dtype = 'float64' ) / self .controller .sampling_frequency + max (t1 , 0 )
54+
55+ scatter_x = []
56+ scatter_y = []
57+ scatter_colors = []
58+ scatter_unit_ids = []
59+
60+ global_to_local_chan_inds = np .zeros (self .controller .channel_ids .size , dtype = 'int64' )
61+ global_to_local_chan_inds [visible_channel_inds ] = np .arange (visible_channel_inds .size , dtype = 'int64' )
62+
63+ for unit_index , unit_id in self .controller .iter_visible_units ():
64+
65+ inds = np .flatnonzero (spikes_chunk ["unit_index" ] == unit_index )
66+ if inds .size == 0 :
67+ continue
68+
69+ # Get spikes for this unit
70+ unit_spikes = spikes_chunk [inds ]
71+ channel_inds = unit_spikes ["channel_index" ]
72+ sample_inds = unit_spikes ["sample_index" ]
73+
74+ chan_mask = np .isin (channel_inds , visible_channel_inds )
75+ if not np .any (chan_mask ):
76+ continue
77+
78+ sample_inds = sample_inds [chan_mask ]
79+ channel_inds = channel_inds [chan_mask ]
80+ # Map channel indices to their positions in visible_channel_inds
81+ local_channel_inds = global_to_local_chan_inds [channel_inds ]
82+
83+
84+ # Calculate y values using signal values
85+ x = times_chunk [sample_inds ]
86+ y = data_curves [local_channel_inds , sample_inds ]
87+
88+ # This should both for qt (QTColor) and panel (html color)
89+ color = self .get_unit_color (unit_id )
90+
91+ scatter_x .extend (x )
92+ scatter_y .extend (y )
93+ scatter_colors .extend ([color ] * len (x ))
94+ scatter_unit_ids .extend ([str (unit_id )] * len (x ))
95+ if not self ._retrieve_traces_time_checked :
96+ t_traces_end = time .perf_counter ()
97+ elapsed = t_traces_end - t_traces_start
98+ if elapsed > self .MAX_RETRIEVE_TIME_FOR_BUSY_CURSOR :
99+ print (f"Trace retrieval took { elapsed :.3f} seconds. Enabling busy cursor." )
100+ self .trace_context = self .busy_cursor
101+ self ._retrieve_traces_time_checked = True
102+
103+ return times_chunk , data_curves , scatter_x , scatter_y , scatter_colors , scatter_unit_ids
104+
16105 ## Qt ##
17106 def _qt_create_toolbar (self ):
18107 from .myqt import QT
@@ -284,6 +373,8 @@ def __init__(self, controller=None, parent=None, backend="qt"):
284373 self .factor = 15.0
285374 self .xsize = 0.5
286375 self ._block_auto_refresh_and_notify = False
376+ self ._retrieve_traces_time_checked = False
377+ self .trace_context = nullcontext
287378
288379 ViewBase .__init__ (self , controller = controller , parent = parent , backend = backend )
289380 MixinViewTrace .__init__ (self )
@@ -308,82 +399,6 @@ def apply_gain_zoom(self, factor_ratio):
308399 self .factor *= factor_ratio
309400 self .refresh ()
310401
311- def get_data_in_chunk (self , t1 , t2 , segment_index ):
312- t_start = 0.0
313- sr = self .controller .sampling_frequency
314-
315- ind1 = max (0 , int ((t1 - t_start ) * sr ))
316- ind2 = min (self .controller .get_num_samples (segment_index ), int ((t2 - t_start ) * sr ))
317-
318- traces_chunk = self .controller .get_traces (segment_index = segment_index , start_frame = ind1 , end_frame = ind2 )
319-
320- sl = self .controller .segment_slices [segment_index ]
321- spikes_seg = self .controller .spikes [sl ]
322- i1 , i2 = np .searchsorted (spikes_seg ["sample_index" ], [ind1 , ind2 ])
323- spikes_chunk = spikes_seg [i1 :i2 ].copy ()
324- spikes_chunk ["sample_index" ] -= ind1
325-
326- visible_channel_inds = self .get_visible_channel_inds ()
327-
328- data_curves = traces_chunk [:, visible_channel_inds ].T .copy ()
329-
330- if data_curves .dtype != "float32" :
331- data_curves = data_curves .astype ("float32" )
332-
333- n = visible_channel_inds .size
334- gains = np .ones (n , dtype = float ) * 1.0 / (self .factor * max (self .mad [visible_channel_inds ]))
335- offsets = np .arange (n )[::- 1 ] - self .med [visible_channel_inds ] * gains
336-
337- data_curves *= gains [:, None ]
338- data_curves += offsets [:, None ]
339-
340- times_chunk = np .arange (traces_chunk .shape [0 ], dtype = 'float64' )/ self .controller .sampling_frequency + max (t1 , 0 )
341-
342- scatter_x = []
343- scatter_y = []
344- scatter_colors = []
345- scatter_unit_ids = []
346-
347- global_to_local_chan_inds = np .zeros (self .controller .channel_ids .size , dtype = 'int64' )
348- global_to_local_chan_inds [visible_channel_inds ] = np .arange (visible_channel_inds .size , dtype = 'int64' )
349-
350-
351- for unit_index , unit_id in self .controller .iter_visible_units ():
352-
353- inds = np .flatnonzero (spikes_chunk ["unit_index" ] == unit_index )
354- if inds .size == 0 :
355- continue
356-
357- # Get spikes for this unit
358- unit_spikes = spikes_chunk [inds ]
359- channel_inds = unit_spikes ["channel_index" ]
360- sample_inds = unit_spikes ["sample_index" ]
361-
362-
363- chan_mask = np .isin (channel_inds , visible_channel_inds )
364- if not np .any (chan_mask ):
365- continue
366-
367- sample_inds = sample_inds [chan_mask ]
368- channel_inds = channel_inds [chan_mask ]
369- # Map channel indices to their positions in visible_channel_inds
370- local_channel_inds = global_to_local_chan_inds [channel_inds ]
371-
372-
373- # Calculate y values using signal values
374- x = times_chunk [sample_inds ]
375- y = data_curves [local_channel_inds , sample_inds ]
376-
377- # This should both for qt (QTColor) and panel (html color)
378- color = self .get_unit_color (unit_id )
379-
380- scatter_x .extend (x )
381- scatter_y .extend (y )
382- scatter_colors .extend ([color ] * len (x ))
383- scatter_unit_ids .extend ([str (unit_id )] * len (x ))
384-
385- return times_chunk , data_curves , scatter_x , scatter_y , scatter_colors , scatter_unit_ids
386-
387402 ## qt ##
388403 def _qt_make_layout (self ):
389404 from .myqt import QT
0 commit comments