Skip to content

Commit a6e5e78

Browse files
authored
Merge branch 'main' into metrics-view
2 parents 847976b + 8827ef4 commit a6e5e78

File tree

5 files changed

+163
-3
lines changed

5 files changed

+163
-3
lines changed

spikeinterface_gui/controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
268268
self.segment_slices = {seg_index: slice(seg_limits[seg_index], seg_limits[seg_index + 1]) for seg_index in range(num_seg)}
269269

270270
spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
271+
self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2]
271272
# this is dict of list because per segment spike_indices[segment_index][unit_id]
272273
spike_indices = spike_vector_to_indices(spike_vector2, unit_ids)
273274
# this is flatten

spikeinterface_gui/layout_presets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_layout_description(preset_name, layout=None):
4242
default_layout = dict(
4343
zone1=['curation', 'spikelist'],
4444
zone2=['unitlist', 'mergelist'],
45-
zone3=['trace', 'tracemap', 'spikeamplitude', 'spikedepth'],
45+
zone3=['trace', 'tracemap', 'spikeamplitude', 'spikedepth', 'spikerate'],
4646
zone4=[],
4747
zone5=['probe'],
4848
zone6=['ndscatter', 'similarity'],

spikeinterface_gui/rateview.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from .view_base import ViewBase
2+
import numpy as np
3+
4+
5+
class SpikeRateView(ViewBase):
6+
_supported_backend = ['qt', 'panel']
7+
_settings = [
8+
{'name': 'bin_s', 'type': 'int', 'value' : 60 },
9+
]
10+
_need_compute = False
11+
12+
def __init__(self, controller=None, parent=None, backend="qt"):
13+
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
14+
15+
def _on_settings_changed(self):
16+
self.refresh()
17+
18+
## Qt ##
19+
20+
def _qt_make_layout(self):
21+
from .myqt import QT
22+
import pyqtgraph as pg
23+
24+
self.layout = QT.QVBoxLayout()
25+
26+
tb = self.qt_widget.view_toolbar
27+
self.combo_seg = QT.QComboBox()
28+
tb.addWidget(self.combo_seg)
29+
self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ])
30+
self.combo_seg.currentIndexChanged.connect(self.refresh)
31+
32+
h = QT.QHBoxLayout()
33+
self.layout.addLayout(h)
34+
35+
self.plot = pg.PlotItem(viewBox=None)
36+
self.graphicsview = pg.GraphicsView()
37+
self.graphicsview.setCentralItem(self.plot)
38+
self.layout.addWidget(self.graphicsview)
39+
40+
def _qt_refresh(self):
41+
import pyqtgraph as pg
42+
43+
self.plot.clear()
44+
45+
seg_index = self.combo_seg.currentIndex()
46+
47+
visible_unit_ids = self.controller.get_visible_unit_ids()
48+
49+
sampling_frequency = self.controller.sampling_frequency
50+
51+
total_frames = self.controller.final_spike_samples
52+
bins_s = self.settings['bin_s']
53+
num_bins = total_frames[seg_index] // int(sampling_frequency) // bins_s
54+
55+
for r, unit_id in enumerate(visible_unit_ids):
56+
57+
spike_inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index)
58+
spikes = self.controller.spikes[spike_inds]['sample_index']
59+
60+
count, bins = np.histogram(spikes, bins=num_bins)
61+
62+
color = self.get_unit_color(unit_id)
63+
curve = pg.PlotCurveItem(
64+
(bins[1:]+bins[:-1])/(2*sampling_frequency),
65+
count/bins_s,
66+
pen=pg.mkPen(color, width=2)
67+
)
68+
self.plot.addItem(curve)
69+
70+
# Make lower y-lim 0
71+
self.plot.getViewBox().autoRange()
72+
current_max_y_range = self.plot.getViewBox().viewRange()[1][1]
73+
self.plot.getViewBox().setYRange(0, current_max_y_range)
74+
75+
## panel ##
76+
77+
def _panel_make_layout(self):
78+
import panel as pn
79+
import bokeh.plotting as bpl
80+
from .utils_panel import _bg_color
81+
82+
self.segment_index = 0
83+
self.segment_selector = pn.widgets.Select(
84+
name="",
85+
options=[f"Segment {i}" for i in range(self.controller.num_segments)],
86+
value=f"Segment {self.segment_index}",
87+
)
88+
self.segment_selector.param.watch(self._panel_change_segment, 'value')
89+
90+
self.rate_fig = bpl.figure(
91+
width=250,
92+
height=250,
93+
tools="pan,wheel_zoom,reset",
94+
active_drag="pan",
95+
active_scroll="wheel_zoom",
96+
background_fill_color=_bg_color,
97+
border_fill_color=_bg_color,
98+
outline_line_color="white",
99+
)
100+
self.rate_fig.toolbar.logo = None
101+
self.rate_fig.grid.visible = False
102+
103+
self.layout = pn.Column(
104+
pn.Row(self.segment_selector, sizing_mode="stretch_width"),
105+
pn.Row(self.rate_fig,sizing_mode="stretch_both"),
106+
)
107+
self.is_warning_active = False
108+
109+
def _panel_refresh(self):
110+
import panel as pn
111+
import bokeh.plotting as bpl
112+
from bokeh.layouts import gridplot
113+
from .utils_panel import _bg_color
114+
115+
seg_index = self.segment_index
116+
117+
visible_unit_ids = self.controller.get_visible_unit_ids()
118+
119+
sampling_frequency = self.controller.sampling_frequency
120+
121+
total_frames = self.controller.final_spike_samples
122+
bins_s = self.settings['bin_s']
123+
num_bins = total_frames[seg_index] // int(sampling_frequency) // bins_s
124+
125+
# clear fig
126+
self.rate_fig.renderers = []
127+
128+
for unit_id in visible_unit_ids:
129+
130+
spike_inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index)
131+
spikes = self.controller.spikes[spike_inds]['sample_index']
132+
133+
count, bins = np.histogram(spikes, bins=num_bins)
134+
135+
# Get color from controller
136+
color = self.get_unit_color(unit_id)
137+
138+
line = self.rate_fig.line(
139+
x=(bins[1:]+bins[:-1])/(2*sampling_frequency),
140+
y=count/bins_s,
141+
color=color,
142+
line_width=2,
143+
)
144+
145+
self.rate_fig.y_range.start = 0
146+
147+
def _panel_change_segment(self, event):
148+
self.segment_index = int(self.segment_selector.value.split()[-1])
149+
self.refresh()
150+
151+
152+
153+
SpikeRateView._gui_help_txt = """
154+
# SpikeRateView View
155+
156+
This view shows firing rate for spikes per `bin_s`.
157+
"""

spikeinterface_gui/tests/test_mainwindow_panel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def test_launcher(verbose=True):
100100
if not test_folder.is_dir():
101101
setup_module()
102102

103-
# win = test_mainwindow(start_app=True, verbose=True, curation=True)
103+
win = test_mainwindow(start_app=True, verbose=True, curation=True)
104104

105-
test_launcher(verbose=True)
105+
# test_launcher(verbose=True)
106106

107107
# TO RUN with panel serve:
108108
# win = test_mainwindow(start_app=False, verbose=True, curation=True)

spikeinterface_gui/viewlist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .curationview import CurationView
1616
from .mainsettingsview import MainSettingsView
1717
from .metricsview import MetricsView
18+
from .rateview import SpikeRateView
1819

1920
possible_class_views = dict(
2021
probe = ProbeView, # probe view is first, since it updates channels upon unit changes
@@ -34,4 +35,5 @@
3435
curation = CurationView,
3536
mainsettings=MainSettingsView,
3637
metrics = MetricsView,
38+
spikerate=SpikeRateView,
3739
)

0 commit comments

Comments
 (0)