Skip to content

Commit 9ead262

Browse files
authored
Merge pull request #175 from samuelgarcia/metrics-view
Metrics view
2 parents 8827ef4 + a6e5e78 commit 9ead262

File tree

6 files changed

+187
-5
lines changed

6 files changed

+187
-5
lines changed

spikeinterface_gui/controller.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,9 @@ def has_extension(self, extension_name):
614614
def handle_metrics(self):
615615
return self.metrics is not None
616616

617+
def get_units_table(self):
618+
return self.units_table
619+
617620
def get_all_pcs(self):
618621

619622
if self._pc_projections is None and self.pc_ext is not None:

spikeinterface_gui/layout_presets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_layout_description(preset_name, layout=None):
4747
zone5=['probe'],
4848
zone6=['ndscatter', 'similarity'],
4949
zone7=['waveform', 'waveformheatmap', ],
50-
zone8=['correlogram', 'isi', 'mainsettings'],
50+
zone8=['correlogram', 'isi', 'metrics', 'mainsettings'],
5151
)
5252
_presets['default'] = default_layout
5353

spikeinterface_gui/metricsview.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import warnings
2+
import numpy as np
3+
4+
5+
from .view_base import ViewBase
6+
7+
from spikeinterface.postprocessing.unit_locations import possible_localization_methods
8+
9+
10+
_default_visible_metrics = ("snr", "firing_rate")
11+
12+
class MetricsView(ViewBase):
13+
_supported_backend = ['qt', ]
14+
_settings = [
15+
{'name': 'num_bins', 'type': 'int', 'value' : 30 },
16+
]
17+
18+
def __init__(self, controller=None, parent=None, backend="qt"):
19+
units_table = controller.get_units_table()
20+
self.visible_metrics_dict = dict()
21+
for col in units_table.columns:
22+
if units_table[col].dtype.kind == "f":
23+
self.visible_metrics_dict[col] = col in _default_visible_metrics
24+
25+
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
26+
27+
28+
29+
## Qt ##
30+
def _qt_make_layout(self):
31+
from .myqt import QT
32+
import pyqtgraph as pg
33+
from .utils_qt import ViewBoxHandlingDoubleClickToPosition
34+
35+
self.layout = QT.QVBoxLayout()
36+
37+
visible_metrics_tree = []
38+
for col, visible in self.visible_metrics_dict.items():
39+
visible_metrics_tree.append(
40+
{'name': str(col), 'type': 'bool', 'value': visible}
41+
)
42+
self.qt_visible_metrics = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_metrics_tree)
43+
self.tree_visible_metrics = pg.parametertree.ParameterTree(parent=self.qt_widget)
44+
self.tree_visible_metrics.header().hide()
45+
self.tree_visible_metrics.setParameters(self.qt_visible_metrics, showTop=True)
46+
# self.tree_visible_metrics.setWindowTitle(u'visible columns')
47+
# self.tree_visible_metrics.setWindowFlags(QT.Qt.Window)
48+
self.qt_visible_metrics.sigTreeStateChanged.connect(self._qt_on_visible_metrics_changed)
49+
self.layout.addWidget(self.tree_visible_metrics)
50+
self.tree_visible_metrics.hide()
51+
52+
53+
tb = self.qt_widget.view_toolbar
54+
but = QT.QPushButton('metrics')
55+
but.clicked.connect(self._qt_select_metrics)
56+
tb.addWidget(but)
57+
58+
self.grid = pg.GraphicsLayoutWidget()
59+
self.layout.addWidget(self.grid)
60+
61+
self._qt_creat_grid()
62+
63+
64+
65+
def _qt_creat_grid(self):
66+
import pyqtgraph as pg
67+
from .myqt import QT
68+
69+
visible_metrics = [k for k, v in self.visible_metrics_dict.items() if v]
70+
self.grid.clear()
71+
n = len(visible_metrics)
72+
if len(visible_metrics) == 0:
73+
return
74+
75+
76+
77+
self.plots = {}
78+
for r in range(n):
79+
for c in range(r, n):
80+
81+
plot = pg.PlotItem()
82+
self.grid.addItem(plot, row=r, col=c)
83+
self.plots[(r, c)] = plot
84+
85+
if r == c:
86+
label_style = {'color': "#7BFF00", 'font-size': '14pt'}
87+
plot.setLabel('bottom', visible_metrics[c], **label_style)
88+
89+
def _qt_refresh(self):
90+
import pyqtgraph as pg
91+
from .myqt import QT
92+
93+
94+
visible_metrics = [k for k, v in self.visible_metrics_dict.items() if v]
95+
n = len(visible_metrics)
96+
97+
units_table = self.controller.get_units_table()
98+
99+
white_brush = QT.QColor('white')
100+
white_brush.setAlpha(200)
101+
102+
103+
for r in range(n):
104+
for c in range(r, n):
105+
col1 = visible_metrics[r]
106+
col2 = visible_metrics[c]
107+
108+
plot = self.plots[(r, c)]
109+
plot.clear()
110+
if c > r:
111+
scatter = pg.ScatterPlotItem(pen=pg.mkPen(None), brush=white_brush, size=11, pxMode = True)
112+
plot.addItem(scatter)
113+
values1 = units_table[col1].values
114+
values2 = units_table[col2].values
115+
116+
scatter.setData(x=values2, y=values1)
117+
118+
visible_unit_ids = self.controller.get_visible_unit_ids()
119+
visible_unit_ids = self.controller.get_visible_unit_indices()
120+
121+
for unit_ind, unit_id in self.controller.iter_visible_units():
122+
color = self.get_unit_color(unit_id)
123+
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)
124+
125+
# self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color)
126+
# self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
127+
elif c == r:
128+
values1 = units_table[visible_metrics[r]].values
129+
130+
count, bins = np.histogram(values1, bins=self.settings['num_bins'])
131+
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=white_brush, pen=white_brush)
132+
plot.addItem(curve)
133+
134+
for unit_ind, unit_id in self.controller.iter_visible_units():
135+
x = values1[unit_ind]
136+
color = self.get_unit_color(unit_id)
137+
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
138+
plot.addItem(line)
139+
140+
141+
# color = self.get_unit_color(unit_id)
142+
# else:
143+
# color = (120,120,120,120)
144+
145+
# curve = pg.PlotCurveItem(self.bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
146+
147+
148+
149+
150+
def _qt_select_metrics(self):
151+
if not self.tree_visible_metrics.isVisible():
152+
self.tree_visible_metrics.show()
153+
else:
154+
self.tree_visible_metrics.hide()
155+
156+
self.layout.addWidget(self.tree_visible_metrics)
157+
158+
def _qt_on_visible_metrics_changed(self):
159+
160+
for col in self.visible_metrics_dict.keys():
161+
# update the internal dict with the qt tree
162+
self.visible_metrics_dict[col] = self.qt_visible_metrics[col]
163+
self._qt_creat_grid()
164+
self.refresh()
165+
166+
167+
168+
169+
170+
## panel ##
171+
def _panel_make_layout(self):
172+
import panel as pn
173+
import bokeh.plotting as bpl
174+
from bokeh.models import ColumnDataSource, HoverTool, Label, PanTool
175+
from bokeh.events import Tap, PanStart, PanEnd
176+
from .utils_panel import CustomCircle, _bg_color

spikeinterface_gui/tests/debug_views.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from pathlib import Path
1515

1616

17-
# test_folder = Path(__file__).parent / 'my_dataset_small'
18-
test_folder = Path(__file__).parent / 'my_dataset_big'
17+
test_folder = Path(__file__).parent / 'my_dataset_small'
18+
# test_folder = Path(__file__).parent / 'my_dataset_big'
1919
# test_folder = Path(__file__).parent / 'my_dataset_multiprobe'
2020

2121

@@ -34,8 +34,9 @@ def debug_one_view():
3434
)
3535

3636
# view_class = possible_class_views['unitlist']
37-
view_class = possible_class_views['mainsettings']
37+
# view_class = possible_class_views['mainsettings']
3838
# view_class = possible_class_views['spikeamplitude']
39+
view_class = possible_class_views['metrics']
3940
widget = ViewWidget(view_class)
4041
view = view_class(controller=controller, parent=widget, backend='qt')
4142
widget.set_view(view)

spikeinterface_gui/tests/test_mainwindow_qt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import sys
1515

1616

17-
test_folder = Path(__file__).parent / 'my_dataset_small'
17+
# test_folder = Path(__file__).parent / 'my_dataset_small'
1818
test_folder = Path(__file__).parent / 'my_dataset_big'
1919
# test_folder = Path(__file__).parent / 'my_dataset_multiprobe'
2020

spikeinterface_gui/viewlist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .tracemapview import TraceMapView
1515
from .curationview import CurationView
1616
from .mainsettingsview import MainSettingsView
17+
from .metricsview import MetricsView
1718
from .rateview import SpikeRateView
1819

1920
possible_class_views = dict(
@@ -33,5 +34,6 @@
3334
tracemap = TraceMapView,
3435
curation = CurationView,
3536
mainsettings=MainSettingsView,
37+
metrics = MetricsView,
3638
spikerate=SpikeRateView,
3739
)

0 commit comments

Comments
 (0)