Skip to content

Commit dc093ef

Browse files
authored
Merge pull request #43 from int-brain-lab/develop
Develop
2 parents ed73643 + cda0af5 commit dc093ef

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+6526
-1031
lines changed

atlaselectrophysiology/ColorBar.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import pyqtgraph as pg
33
import matplotlib
44
import numpy as np
5+
from pyqtgraph.functions import makeARGB
56

67

78
class ColorBar(pg.GraphicsWidget):
89

9-
def __init__(self, cmap_name, cbin=256, parent=None):
10+
def __init__(self, cmap_name, cbin=256, parent=None, data=None):
1011
pg.GraphicsWidget.__init__(self)
1112

1213
# Create colour map from matplotlib colourmap name
@@ -23,6 +24,13 @@ def __init__(self, cmap_name, cbin=256, parent=None):
2324
self.lut = self.map.getLookupTable()
2425
self.grad = self.map.getGradient()
2526

27+
def getBrush(self, data, levels=None):
28+
if levels is None:
29+
levels = [np.min(data), np.max(data)]
30+
brush_rgb, _ = makeARGB(data[:, np.newaxis], levels=levels, lut=self.lut, useRGBA=True)
31+
brush = [QtGui.QColor(*col) for col in np.squeeze(brush_rgb)]
32+
return brush
33+
2634
def getColourMap(self):
2735
return self.lut
2836

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
from easyqc.gui import viewseis
2+
from ibllib.dsp import voltage
3+
from ibllib.ephys import neuropixel
4+
from viewspikes.data import stream, get_ks2
5+
from viewspikes.plots import overlay_spikes
6+
import scipy
7+
from PyQt5 import QtCore, QtGui
8+
import numpy as np
9+
import pyqtgraph as pg
10+
import qt
11+
from one.api import ONE
12+
from iblutil.util import Bunch
13+
14+
import atlaselectrophysiology.ephys_atlas_gui as alignment_window
15+
import data_exploration_gui.gui_main as trial_window
16+
17+
18+
class AlignmentWindow(alignment_window.MainWindow):
19+
def __init__(self, probe_id=None, one=None, histology=False, spike_collection=None):
20+
21+
self.ap = None # spikeglx.Reader for ap band
22+
self.lf = None # spikeglx.Reader for lf band
23+
self.line_x = None
24+
self.trial_curve = None
25+
self.time_plot = None
26+
self.trial_gui = None
27+
self.clicked = None
28+
self.eqc = {} # handles for viewdata windows
29+
30+
super(AlignmentWindow, self).__init__(probe_id=probe_id, one=one, histology=histology,
31+
spike_collection=spike_collection)
32+
# remove the lines from the plots
33+
self.remove_lines_points()
34+
self.lines_features = []
35+
self.lines_tracks = []
36+
self.points = []
37+
38+
def on_mouse_double_clicked(self, event):
39+
if not self.offline:
40+
if event.double() and event.modifiers() and QtCore.Qt.ShiftModifier:
41+
pos = self.data_plot.mapFromScene(event.scenePos())
42+
if self.line_x is not None:
43+
self.fig_img.removeItem(self.line_x)
44+
45+
self.line_x = pg.InfiniteLine(pos=pos.x() * self.x_scale, angle=90,
46+
pen=self.kpen_dot, movable=False)
47+
self.line_x.setZValue(100)
48+
self.fig_img.addItem(self.line_x)
49+
self.stream_ap(pos.x() * self.x_scale)
50+
self.stream_lf(pos.x() * self.x_scale)
51+
52+
return
53+
54+
def plot_image(self, data):
55+
super().plot_image(data)
56+
self.remove_trial_curve(data['xaxis'])
57+
self.remove_line_x(data['xaxis'])
58+
if 'Time' in data['xaxis']:
59+
self.time_plot = True
60+
else:
61+
self.time_plot = False
62+
63+
def plot_scatter(self, data):
64+
super().plot_scatter(data)
65+
self.remove_trial_curve(data['xaxis'])
66+
self.remove_line_x(data['xaxis'])
67+
if 'Time' in data['xaxis']:
68+
self.time_plot = True
69+
else:
70+
self.time_plot = False
71+
72+
def add_trials(self, trial_key='feedback_times'):
73+
self.selected_trials = self.plotdata.trials[trial_key]
74+
x, y = self.vertical_lines(self.selected_trials, 0, 3840)
75+
self.trial_curve = pg.PlotCurveItem()
76+
self.trial_curve.setData(x=x, y=y, pen=self.rpen_dot, connect='finite')
77+
self.trial_curve.setClickable(True)
78+
self.fig_img.addItem(self.trial_curve)
79+
self.fig_img.scene().sigMouseClicked.connect(self.on_mouse_clicked)
80+
self.trial_curve.sigClicked.connect(self.trial_line_clicked)
81+
82+
def remove_trials(self):
83+
self.fig_img.removeItem(self.trial_curve)
84+
85+
def vertical_lines(self, x, ymin, ymax):
86+
87+
x = np.tile(x, (3, 1))
88+
x[2, :] = np.nan
89+
y = np.zeros_like(x)
90+
y[0, :] = ymin
91+
y[1, :] = ymax
92+
y[2, :] = np.nan
93+
94+
return x.T.flatten(), y.T.flatten()
95+
96+
def trial_line_clicked(self, ev):
97+
self.clicked = ev
98+
99+
def on_mouse_clicked(self, event):
100+
if self.trial_gui is not None:
101+
if not event.double() and type(self.clicked) == pg.PlotCurveItem:
102+
self.pos = self.data_plot.mapFromScene(event.scenePos())
103+
x = self.pos.x() * self.x_scale
104+
trial_id = np.argmin(np.abs(self.selected_trials - x))
105+
print(trial_id)
106+
107+
idx = np.where(self.trial_gui.data.y == 10 * trial_id)
108+
self.trial_scat = pg.ScatterPlotItem()
109+
self.trial_gui.plots.fig4_raster.fig.addItem(self.trial_scat)
110+
self.trial_scat.setData(self.trial_gui.data.x[idx], self.trial_gui.data.y[idx],
111+
brush='r', size=5)
112+
113+
self.clicked = None
114+
115+
def stream_lf(self, t):
116+
if self.lf is not None:
117+
self.lf.close()
118+
119+
self.lf, dsets, t0 = stream(
120+
self.loaddata.probe_id, t=t, one=self.loaddata.one, cache=True, typ='lf')
121+
sos = scipy.signal.butter(3, 5 / self.lf.fs / 2, btype='highpass', output='sos')
122+
butt = scipy.signal.sosfiltfilt(sos, self.lf[:, :-1].T)
123+
h = neuropixel.trace_header()
124+
self.eqc['raw_lf'] = viewseis(
125+
butt.T, si=1 / self.lf.fs, h=h, t0=t0, title='raw_lf', taxis=0)
126+
self.lf.close()
127+
128+
def stream_ap(self, t):
129+
if self.ap is not None:
130+
self.ap.close()
131+
132+
self.ap, dsets, t0 = stream(
133+
self.loaddata.probe_id, t=t, one=self.loaddata.one, cache=True)
134+
raw = self.ap[:, :-1].T
135+
h = neuropixel.trace_header()
136+
sos = scipy.signal.butter(3, 300 / self.ap.fs / 2, btype='highpass', output='sos')
137+
butt = scipy.signal.sosfiltfilt(sos, raw)
138+
destripe = voltage.destripe(raw, fs=self.ap.fs)
139+
ks2 = get_ks2(raw, dsets, self.loaddata.one)
140+
self.eqc['butterworth'] = viewseis(butt.T, si=1 / self.ap.fs, h=h, t0=t0, title='butt',
141+
taxis=0)
142+
self.eqc['destripe'] = viewseis(destripe.T, si=1 / self.ap.fs, h=h, t0=t0, title='destr',
143+
taxis=0)
144+
self.eqc['ks2'] = viewseis(ks2.T, si=1 / self.ap.fs, h=h, t0=t0, title='ks2', taxis=0)
145+
146+
overlay_spikes(self.eqc['butterworth'], self.plotdata.spikes, self.plotdata.clusters,
147+
self.plotdata.channels)
148+
overlay_spikes(self.eqc['destripe'], self.plotdata.spikes, self.plotdata.clusters,
149+
self.plotdata.channels)
150+
overlay_spikes(self.eqc['ks2'], self.plotdata.spikes, self.plotdata.clusters,
151+
self.plotdata.channels)
152+
self.ap.close()
153+
154+
def remove_line_x(self, xaxis):
155+
"""
156+
If we have any horizontal lines to indicate the time points delete them if the x axis is
157+
not time
158+
:param xaxis:
159+
:return:
160+
"""
161+
if self.line_x is not None:
162+
self.fig_img.removeItem(self.line_x)
163+
if 'Time' in xaxis:
164+
self.fig_img.addItem(self.line_x)
165+
166+
def remove_trial_curve(self, xaxis):
167+
if self.trial_curve is not None:
168+
self.fig_img.removeItem(self.trial_curve)
169+
if 'Time' in xaxis:
170+
self.fig_img.addItem(self.trial_curve)
171+
172+
def remove_lines_points(self):
173+
super().remove_lines_points()
174+
self.remove_line_x('la')
175+
self.remove_trial_curve('la')
176+
177+
def add_lines_points(self):
178+
super().add_lines_points()
179+
if self.time_plot:
180+
self.remove_line_x('Time')
181+
self.remove_trial_curve('Time')
182+
183+
def closeEvent(self, event):
184+
"""
185+
Close the spikeglx file when window is closed
186+
"""
187+
super().closeEvent(event)
188+
if self.ap is not None:
189+
self.ap.close()
190+
191+
def complete_button_pressed(self):
192+
QtGui.QMessageBox.information(self, 'Status', ("Not going to upload any results, to do"
193+
" an alignment, launch normally"))
194+
195+
196+
class TrialWindow(trial_window.MainWindow):
197+
def __init__(self):
198+
super(TrialWindow, self).__init__()
199+
self.alignment_gui = None
200+
self.scat = None
201+
202+
def on_scatter_plot_clicked(self, scatter, point):
203+
super().on_scatter_plot_clicked(scatter, point)
204+
self.add_clust_scatter()
205+
206+
def on_cluster_list_clicked(self):
207+
super().on_cluster_list_clicked()
208+
self.add_clust_scatter()
209+
210+
def on_next_cluster_clicked(self):
211+
super().on_next_cluster_clicked()
212+
self.add_clust_scatter()
213+
214+
def on_previous_cluster_clicked(self):
215+
super().on_previous_cluster_clicked()
216+
self.add_clust_scatter()
217+
218+
def add_clust_scatter(self):
219+
if not self.scat:
220+
self.scat = pg.ScatterPlotItem()
221+
self.alignment_gui.fig_img.addItem(self.scat)
222+
223+
self.scat.setData(self.data.spikes.times[self.data.clus_idx],
224+
self.data.spikes.depths[self.data.clus_idx], brush='g', size=5)
225+
226+
227+
def load_extra_data(probe_id, one=None, spike_collection=None):
228+
one = one or ONE()
229+
eid, probe = one.pid2eid(probe_id)
230+
if spike_collection:
231+
collection = f'alf/{probe}/{spike_collection}'
232+
else:
233+
collection = f'alf/{probe}'
234+
235+
_ = one.load_object(eid, obj='spikes', collection=collection,
236+
attribute='samples')
237+
trials = one.load_object(eid, obj='trials')
238+
239+
return trials
240+
241+
242+
def viewer(probe_id=None, one=None, data_explore=False, spike_collection=None):
243+
"""
244+
"""
245+
qt.create_app()
246+
trials = load_extra_data(probe_id, one=one, spike_collection=spike_collection)
247+
av = AlignmentWindow(probe_id=probe_id, one=one, spike_collection=spike_collection)
248+
av.plotdata.trials = trials
249+
av.show()
250+
251+
if data_explore:
252+
data = Bunch()
253+
data['spikes'] = av.plotdata.spikes
254+
data['clusters'] = av.plotdata.clusters
255+
data['trials'] = av.plotdata.trials
256+
bv = TrialWindow()
257+
bv.on_data_given(data)
258+
av.trial_gui = bv
259+
bv.alignment_gui = av
260+
bv.show()
261+
262+
return av

atlaselectrophysiology/compare_alignments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# import modules
2-
from oneibl.one import ONE
2+
from one.api import ONE
33
from ibllib.pipes.ephys_alignment import EphysAlignment
44
import numpy as np
55
import matplotlib.pyplot as plt

atlaselectrophysiology/create_overview_plots.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,25 @@ def load_image(image_name, ax):
2727
gs = fig.add_gridspec(3, 18)
2828
gs.update(wspace=0.025, hspace=0.05)
2929

30+
ignore_img_plots = ['leftGabor', 'rightGabor', 'noiseOn', 'valveOn', 'toneOn']
3031
img_row_order = [0, 0, 0, 0, 0, 0, 1, 1, 1]
3132
img_column_order = [0, 3, 6, 9, 12, 15, 0, 3, 6]
3233
img_idx = [0, 5, 4, 6, 7, 8, 1, 2, 3]
3334
img_files = glob.glob(str(image_folder.joinpath(image_info + 'img_*.png')))
35+
img_files = [img for img in img_files if not any([ig in img for ig in ignore_img_plots])]
3436
img_files_sort = [img_files[idx] for idx in img_idx]
3537

3638
for iF, file in enumerate(img_files_sort):
3739
ax = fig.add_subplot(gs[img_row_order[iF], img_column_order[iF]:img_column_order[iF] + 3])
3840
load_image(Path(file), ax)
3941

42+
ignore_probe_plots = ['RF Map']
4043
probe_row_order = [1, 1, 1, 1, 1, 1, 2, 2, 2]
4144
probe_column_order = [9, 10, 11, 12, 13, 14, 12, 13, 14]
4245
probe_idx = [0, 3, 1, 2, 4, 5, 6]
4346
probe_files = glob.glob(str(image_folder.joinpath(image_info + 'probe_*.png')))
47+
probe_files = [probe for probe in probe_files if not any([pr in probe for pr in
48+
ignore_probe_plots])]
4449
probe_files_sort = [probe_files[idx] for idx in probe_idx]
4550
line_files = glob.glob(str(image_folder.joinpath(image_info + 'line_*.png')))
4651

@@ -75,7 +80,7 @@ def load_image(image_name, ax):
7580
load_image(Path(file), ax)
7681

7782
ax.text(0.5, 0, image_info[:-1], va="center", ha="center", transform=ax.transAxes)
78-
7983
plt.savefig(save_folder.joinpath(image_info + "overview.png"),
8084
bbox_inches='tight', pad_inches=0)
81-
plt.show()
85+
# plt.close()
86+
# plt.show()

0 commit comments

Comments
 (0)