33import numpy as np
44import one .alf as alf
55from brainbox .processing import bincount2D
6+ from brainbox .io .spikeglx import stream
67from brainbox .population .decode import xcorr
78from brainbox .task import passive
9+ from ibllib .dsp import voltage
810import scipy
911from PyQt5 import QtGui
1012
@@ -23,8 +25,9 @@ def __init__(self, probe_path, ephys_path, alf_path, shank_idx):
2325 self .ephys_path = ephys_path
2426 self .alf_path = alf_path
2527
26- self .chn_coords_all = np .load (Path (self .probe_path , 'channels.localCoordinates.npy' ))
27- self .chn_ind_all = np .load (Path (self .probe_path , 'channels.rawInd.npy' ))
28+ channels = alf .io .load_object (self .probe_path , 'channels' )
29+ self .chn_coords_all = channels ['localCoordinates' ]
30+ self .chn_ind_all = channels ['rawInd' ].astype (int )
2831
2932 self .chn_min = np .min (self .chn_coords_all [:, 1 ])
3033 self .chn_max = np .max (self .chn_coords_all [:, 1 ])
@@ -71,7 +74,8 @@ def __init__(self, probe_path, ephys_path, alf_path, shank_idx):
7174 self .cluster_data_status = True
7275 self .compute_timescales ()
7376
74- except Exception :
77+ except Exception as err :
78+ print (err )
7579 print ('cluster data was not found, some plots will not display' )
7680 self .cluster_data_status = False
7781
@@ -295,9 +299,11 @@ def get_fr_img(self):
295299 else :
296300 T_BIN = 0.05
297301 D_BIN = 5
302+ chn_min = np .min (np .r_ [self .chn_min , self .spikes ['depths' ][self .spike_idx ][self .kp_idx ]])
303+ chn_max = np .max (np .r_ [self .chn_max , self .spikes ['depths' ][self .spike_idx ][self .kp_idx ]])
298304 n , times , depths = bincount2D (self .spikes ['times' ][self .spike_idx ][self .kp_idx ],
299305 self .spikes ['depths' ][self .spike_idx ][self .kp_idx ],
300- T_BIN , D_BIN , ylim = [self . chn_min , self . chn_max ])
306+ T_BIN , D_BIN , ylim = [chn_min , chn_max ])
301307 img = n .T / T_BIN
302308 xscale = (times [- 1 ] - times [0 ]) / img .shape [0 ]
303309 yscale = (depths [- 1 ] - depths [0 ]) / img .shape [1 ]
@@ -323,14 +329,16 @@ def get_fr_amp_data_line(self):
323329 else :
324330 T_BIN = np .max (self .spikes ['times' ])
325331 D_BIN = 10
332+ chn_min = np .min (np .r_ [self .chn_min , self .spikes ['depths' ][self .spike_idx ][self .kp_idx ]])
333+ chn_max = np .max (np .r_ [self .chn_max , self .spikes ['depths' ][self .spike_idx ][self .kp_idx ]])
326334 nspikes , times , depths = bincount2D (self .spikes ['times' ][self .spike_idx ][self .kp_idx ],
327335 self .spikes ['depths' ][self .spike_idx ][self .kp_idx ],
328336 T_BIN , D_BIN ,
329- ylim = [self . chn_min , self . chn_max ])
337+ ylim = [chn_min , chn_max ])
330338
331339 amp , times , depths = bincount2D (self .spikes ['amps' ][self .spike_idx ][self .kp_idx ],
332340 self .spikes ['depths' ][self .spike_idx ][self .kp_idx ],
333- T_BIN , D_BIN , ylim = [self . chn_min , self . chn_max ],
341+ T_BIN , D_BIN , ylim = [chn_min , chn_max ],
334342 weights = self .spikes ['amps' ][self .spike_idx ]
335343 [self .kp_idx ])
336344 mean_fr = nspikes [:, 0 ] / T_BIN
@@ -363,9 +371,11 @@ def get_correlation_data_img(self):
363371 else :
364372 T_BIN = 0.05
365373 D_BIN = 40
374+ chn_min = np .min (np .r_ [self .chn_min , self .spikes ['depths' ][self .spike_idx ][self .kp_idx ]])
375+ chn_max = np .max (np .r_ [self .chn_max , self .spikes ['depths' ][self .spike_idx ][self .kp_idx ]])
366376 R , times , depths = bincount2D (self .spikes ['times' ][self .spike_idx ][self .kp_idx ],
367377 self .spikes ['depths' ][self .spike_idx ][self .kp_idx ],
368- T_BIN , D_BIN , ylim = [self . chn_min , self . chn_max ])
378+ T_BIN , D_BIN , ylim = [chn_min , chn_max ])
369379 corr = np .corrcoef (R )
370380 corr [np .isnan (corr )] = 0
371381 scale = (np .max (depths ) - np .min (depths )) / corr .shape [0 ]
@@ -464,6 +474,38 @@ def median_subtract(a):
464474
465475 return data_img , data_probe
466476
477+ # only for IBL sorry
478+ def get_raw_data_image (self , pid , t0 = (1000 , 2000 , 3000 ), one = None ):
479+
480+ def gain2level (gain ):
481+ return 10 ** (gain / 20 ) * 4 * np .array ([- 1 , 1 ])
482+ data_img = dict ()
483+ for t in t0 :
484+
485+ sr , t = stream (pid , t , one = one )
486+ raw = sr [:, :- sr .nsync ].T
487+ channel_labels , channel_features = voltage .detect_bad_channels (raw , sr .fs )
488+ raw = voltage .destripe (raw , fs = sr .fs , channel_labels = channel_labels )
489+ raw_image = raw [:, int ((450 / 1e3 ) * sr .fs ):int ((500 / 1e3 ) * sr .fs )].T
490+ x_range = np .array ([0 , raw_image .shape [0 ] - 1 ]) / sr .fs * 1e3
491+ levels = gain2level (- 90 )
492+ xscale = (x_range [1 ] - x_range [0 ]) / raw_image .shape [0 ]
493+ yscale = (self .chn_max - self .chn_min ) / raw_image .shape [1 ]
494+
495+ data_raw = {
496+ 'img' : raw_image ,
497+ 'scale' : np .array ([xscale , yscale ]),
498+ 'levels' : levels ,
499+ 'offset' : np .array ([0 , 0 ]),
500+ 'cmap' : 'bone' ,
501+ 'xrange' : x_range ,
502+ 'xaxis' : 'Time (ms)' ,
503+ 'title' : 'Power (uV)'
504+ }
505+ data_img [f'Raw data t={ t } ' ] = data_raw
506+
507+ return data_img
508+
467509 def get_lfp_spectrum_data (self ):
468510 freq_bands = np .vstack (([0 , 4 ], [4 , 10 ], [10 , 30 ], [30 , 80 ], [80 , 200 ]))
469511 data_probe = {}
0 commit comments