Skip to content

Commit 84483b2

Browse files
committed
Add metrics view.
Only for Qt at the moment
1 parent 4287736 commit 84483b2

File tree

4 files changed

+149
-9
lines changed

4 files changed

+149
-9
lines changed

spikeinterface_gui/controller.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,9 @@ def has_extension(self, extension_name):
613613
def handle_metrics(self):
614614
return self.metrics is not None
615615

616+
def get_units_table(self):
617+
return self.units_table
618+
616619
def get_all_pcs(self):
617620

618621
if self._pc_projections is None and self.pc_ext is not None:

spikeinterface_gui/metricsview.py

Lines changed: 141 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,24 @@
77
from spikeinterface.postprocessing.unit_locations import possible_localization_methods
88

99

10+
_default_visible_metrics = ("snr", "firing_rate")
11+
1012
class MetricsView(ViewBase):
1113
_supported_backend = ['qt', ]
1214
_settings = [
13-
]
15+
{'name': 'num_bins', 'type': 'int', 'value' : 30 },
16+
]
1417

1518
def __init__(self, controller=None, parent=None, backend="qt"):
16-
self.contact_positions = controller.get_contact_location()
17-
self.probes = controller.get_probegroup().probes
19+
units_table = controller.get_units_table()
20+
self.visible_metrics_dict = dict()
21+
for col in units_table.columns:
22+
if units_table[col].dtype.kind == "f":
23+
self.visible_metrics_dict[col] = col in _default_visible_metrics
24+
1825
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
19-
self._unit_positions = self.controller.unit_positions
26+
27+
2028

2129
## Qt ##
2230
def _qt_make_layout(self):
@@ -26,8 +34,136 @@ def _qt_make_layout(self):
2634

2735
self.layout = QT.QVBoxLayout()
2836

37+
visible_metrics_tree = []
38+
for col, visible in self.visible_metrics_dict.items():
39+
visible_metrics_tree.append(
40+
{'name': str(col), 'type': 'bool', 'value': visible}
41+
)
42+
self.qt_visible_metrics = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_metrics_tree)
43+
self.tree_visible_metrics = pg.parametertree.ParameterTree(parent=self.qt_widget)
44+
self.tree_visible_metrics.header().hide()
45+
self.tree_visible_metrics.setParameters(self.qt_visible_metrics, showTop=True)
46+
# self.tree_visible_metrics.setWindowTitle(u'visible columns')
47+
# self.tree_visible_metrics.setWindowFlags(QT.Qt.Window)
48+
self.qt_visible_metrics.sigTreeStateChanged.connect(self._qt_on_visible_metrics_changed)
49+
self.layout.addWidget(self.tree_visible_metrics)
50+
self.tree_visible_metrics.hide()
51+
52+
53+
tb = self.qt_widget.view_toolbar
54+
but = QT.QPushButton('metrics')
55+
but.clicked.connect(self._qt_select_metrics)
56+
tb.addWidget(but)
57+
58+
self.grid = pg.GraphicsLayoutWidget()
59+
self.layout.addWidget(self.grid)
60+
61+
self._qt_creat_grid()
62+
63+
64+
65+
def _qt_creat_grid(self):
66+
import pyqtgraph as pg
67+
from .myqt import QT
68+
69+
visible_metrics = [k for k, v in self.visible_metrics_dict.items() if v]
70+
print(visible_metrics)
71+
self.grid.clear()
72+
n = len(visible_metrics)
73+
if len(visible_metrics) == 0:
74+
return
75+
76+
77+
78+
self.plots = {}
79+
for r in range(n):
80+
for c in range(r, n):
81+
82+
plot = pg.PlotItem()
83+
self.grid.addItem(plot, row=r, col=c)
84+
self.plots[(r, c)] = plot
85+
86+
if r == c:
87+
label_style = {'color': "#7BFF00", 'font-size': '14pt'}
88+
plot.setLabel('bottom', visible_metrics[c], **label_style)
89+
2990
def _qt_refresh(self):
30-
pass
91+
import pyqtgraph as pg
92+
from .myqt import QT
93+
94+
95+
visible_metrics = [k for k, v in self.visible_metrics_dict.items() if v]
96+
n = len(visible_metrics)
97+
98+
units_table = self.controller.get_units_table()
99+
100+
white_brush = QT.QColor('white')
101+
white_brush.setAlpha(200)
102+
103+
104+
for r in range(n):
105+
for c in range(r, n):
106+
col1 = visible_metrics[r]
107+
col2 = visible_metrics[c]
108+
109+
plot = self.plots[(r, c)]
110+
plot.clear()
111+
if c > r:
112+
scatter = pg.ScatterPlotItem(pen=pg.mkPen(None), brush=white_brush, size=11, pxMode = True)
113+
plot.addItem(scatter)
114+
values1 = units_table[col1].values
115+
values2 = units_table[col2].values
116+
117+
scatter.setData(x=values2, y=values1)
118+
119+
visible_unit_ids = self.controller.get_visible_unit_ids()
120+
visible_unit_ids = self.controller.get_visible_unit_indices()
121+
122+
for unit_ind, unit_id in self.controller.iter_visible_units():
123+
color = self.get_unit_color(unit_id)
124+
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)
125+
126+
# self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color)
127+
# self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
128+
elif c == r:
129+
values1 = units_table[visible_metrics[r]].values
130+
131+
count, bins = np.histogram(values1, bins=self.settings['num_bins'])
132+
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=white_brush, pen=white_brush)
133+
plot.addItem(curve)
134+
135+
for unit_ind, unit_id in self.controller.iter_visible_units():
136+
x = values1[unit_ind]
137+
color = self.get_unit_color(unit_id)
138+
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
139+
plot.addItem(line)
140+
141+
142+
# color = self.get_unit_color(unit_id)
143+
# else:
144+
# color = (120,120,120,120)
145+
146+
# curve = pg.PlotCurveItem(self.bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
147+
148+
149+
150+
151+
def _qt_select_metrics(self):
152+
if not self.tree_visible_metrics.isVisible():
153+
self.tree_visible_metrics.show()
154+
else:
155+
self.tree_visible_metrics.hide()
156+
157+
self.layout.addWidget(self.tree_visible_metrics)
158+
159+
def _qt_on_visible_metrics_changed(self):
160+
161+
for col in self.visible_metrics_dict.keys():
162+
# update the internal dict with the qt tree
163+
self.visible_metrics_dict[col] = self.qt_visible_metrics[col]
164+
self._qt_creat_grid()
165+
self.refresh()
166+
31167

32168

33169

spikeinterface_gui/tests/debug_views.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from pathlib import Path
1515

1616

17-
# test_folder = Path(__file__).parent / 'my_dataset_small'
18-
test_folder = Path(__file__).parent / 'my_dataset_big'
17+
test_folder = Path(__file__).parent / 'my_dataset_small'
18+
# test_folder = Path(__file__).parent / 'my_dataset_big'
1919
# test_folder = Path(__file__).parent / 'my_dataset_multiprobe'
2020

2121

@@ -34,8 +34,9 @@ def debug_one_view():
3434
)
3535

3636
# view_class = possible_class_views['unitlist']
37-
view_class = possible_class_views['mainsettings']
37+
# view_class = possible_class_views['mainsettings']
3838
# view_class = possible_class_views['spikeamplitude']
39+
view_class = possible_class_views['metrics']
3940
widget = ViewWidget(view_class)
4041
view = view_class(controller=controller, parent=widget, backend='qt')
4142
widget.set_view(view)

spikeinterface_gui/tests/test_mainwindow_qt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
test_folder = Path(__file__).parent / 'my_dataset_small'
18-
test_folder = Path(__file__).parent / 'my_dataset_big'
18+
# test_folder = Path(__file__).parent / 'my_dataset_big'
1919
# test_folder = Path(__file__).parent / 'my_dataset_multiprobe'
2020

2121
# yep is for testing

0 commit comments

Comments
 (0)