from dataclasses import dataclass, field
import gc
import logging
+import re
import os
from pathlib import Path

-
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from neuropixel import TIP_SIZE_UM, trace_header
import spikeglx

+import ibldsp.voltage
from iblutil.util import Bunch
-from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
from iblatlas.atlas import AllenAtlas, BrainRegions
from iblatlas import atlas
+from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
from ibllib.pipes import histology
from ibllib.pipes.ephys_alignment import EphysAlignment
-from ibllib.plots import vertical_lines
+from ibllib.plots import vertical_lines, Density

import brainbox.plot
from brainbox.io.spikeglx import Streamer
@@ -916,16 +917,18 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
            if missing == 'raise':
                raise e

-    def download_spike_sorting(self, **kwargs):
+    def download_spike_sorting(self, objects=None, **kwargs):
        """
        Downloads spikes, clusters and channels
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types
+        :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
        :return:
        """
-        for obj in ['spikes', 'clusters', 'channels']:
+        objects = ['spikes', 'clusters', 'channels'] if objects is None else objects
+        for obj in objects:
            self.download_spike_sorting_object(obj=obj, **kwargs)
-        self.spike_sorting_path = self.files['spikes'][0].parent
+        self.spike_sorting_path = self.files['clusters'][0].parent

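# A minimal usage sketch of the new `objects` argument, assuming a configured ONE instance;
# the insertion id below is a placeholder.
from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(pid='00000000-0000-0000-0000-000000000000', one=one)  # placeholder pid
ssl.download_spike_sorting(objects=['clusters', 'channels'])  # skip the large spikes table
ssl.download_spike_sorting()  # unchanged default: ['spikes', 'clusters', 'channels']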
    def download_raw_electrophysiology(self, band='ap'):
        """
@@ -963,7 +966,7 @@ def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
            return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs)
        else:
            raw_data_files = self.download_raw_electrophysiology(band=band)
-            cbin_file = next(filter(lambda f: f.name.endswith(f'.{band}.cbin'), raw_data_files), None)
+            cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
            if cbin_file is not None:
                return spikeglx.Reader(cbin_file)

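# A quick check of the new filename match: unlike str.endswith(f'.{band}.cbin'), the regex also
# accepts an extra token between the band and the cbin extension. File names below are purely
# illustrative, not actual dataset names.
import re

band = 'ap'
for name in ['_spikeglx_ephysData_g0_t0.imec0.ap.cbin',         # matched by both old and new logic
             '_spikeglx_ephysData_g0_t0.imec0.ap.stream.cbin',  # extra token: new regex only
             '_spikeglx_ephysData_g0_t0.imec0.lf.cbin']:        # wrong band: rejected by both
    print(name, bool(re.match(rf".*\.{band}\..*cbin", name)))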
@@ -999,7 +1002,7 @@ def load_channels(self, **kwargs):
            self.histology = 'alf'
        return channels

-    def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
+    def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
        """
        Loads spikes, clusters and channels

@@ -1013,20 +1016,44 @@ def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
        - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data

        :param spike_sorter: (defaults to 'pykilosort')
-        :param dataset_types: list of extra dataset types
+        :param revision: for example "2024-05-06" (defaults to None)
+        :param enforce_version: if True, will raise an error if the spike sorting version and revision are not the expected ones
+        :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
+        :param good_units: False (default). If True, loads only the good units, possibly by downloading a smaller spikes table
+        :param kwargs: additional arguments to be passed to one.api.One.load_object
        :return:
        """
        if len(self.collections) == 0:
            return {}, {}, {}
        self.files = {}
        self.spike_sorter = spike_sorter
-        self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs)
-        channels = self.load_channels(spike_sorter=spike_sorter, **kwargs)
+        self.revision = revision
+        objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
+        self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
+        channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
        clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
-        spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
-
+        if good_units:
+            spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
+        else:
+            spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
+        if enforce_version:
+            self._assert_version_consistency()
        return spikes, clusters, channels

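# A minimal sketch of the new loading options, reusing the `ssl` loader from the sketch above;
# the revision date mirrors the docstring example and pykilosort is simply the default sorter name.
spikes, clusters, channels = ssl.load_spike_sorting(
    revision='2024-05-06',   # pin a dated revision of the spike sorting output
    enforce_version=True,    # raise if the returned files do not match sorter and revision
    good_units=True,         # load the smaller passingSpikes table instead of the full spikes table
)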
+    def _assert_version_consistency(self):
+        """
+        Makes sure the state of the spike sorting object matches the files downloaded
+        :return: None
+        """
+        for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
+            for fn in self.files.get(k, []):
+                if self.spike_sorter:
+                    assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
+                        f"You required strict version {self.spike_sorter}, {fn} does not match"
+                if self.revision:
+                    assert fn.relative_to(self.session_path).parts[3] == f"#{self.revision}#", \
+                        f"You required strict revision {self.revision}, {fn} does not match"
+
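# The parts[2] / parts[3] checks above assume the session-relative ALF layout
# alf/<probe>/<spike_sorter>/#<revision>#/<dataset>; a pathlib illustration with made-up paths:
from pathlib import Path

session_path = Path('/data/Subjects/ZZ000/2024-05-06/001')                  # made-up session path
fn = session_path / 'alf/probe00/pykilosort/#2024-05-06#/spikes.times.npy'  # made-up dataset path
parts = fn.relative_to(session_path).parts
print(parts[2])  # 'pykilosort'   -> compared against self.spike_sorter
print(parts[3])  # '#2024-05-06#' -> compared against f"#{self.revision}#"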
    @staticmethod
    def compute_metrics(spikes, clusters=None):
        nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
@@ -1079,6 +1106,8 @@ def _get_probe_info(self):
        if self._sync is None:
            timestamps = self.one.load_dataset(
                self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
+            _ = self.one.load_dataset(  # this is not used here but we want to trigger the download for potential tasks
+                self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
            try:
                ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
                    self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
@@ -1116,7 +1145,13 @@ def samples2times(self, values, direction='forward'):
    def pid2ref(self):
        return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"

-    def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, **kwargs):
+    def _default_plot_title(self, spikes):
+        title = f"{self.pid2ref}, {self.pid}\n" \
+            f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
+        return title
+
+    def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
+               drift=None, title=None, **kwargs):
        """
        :param spikes: spikes dictionary or Bunch
        :param channels: channels dictionary or Bunch.
@@ -1138,9 +1173,9 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
            # set default raster plot parameters
            kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
        brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs)
-        title_str = f"{self.pid2ref}, {self.pid}\n" \
-            f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
-        axs[0, 0].title.set_text(title_str)
+        if title is None:
+            title = self._default_plot_title(spikes)
+        axs[0, 0].title.set_text(title)
        for k, ts in time_series.items():
            vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
        if 'atlas_id' in channels:
@@ -1150,10 +1185,55 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
        axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1])
        fig.tight_layout()

-        self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
-        if 'drift' in self.files:
-            drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
+        if drift is None:
+            self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
+            if 'drift' in self.files:
+                drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
+        if isinstance(drift, dict):
            axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
+            axs[0, 0].set(ylim=[-15, 15])
+
+        if save_dir is not None:
+            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
+            fig.savefig(png_file)
+            plt.close(fig)
+            gc.collect()
+        else:
+            return fig, axs
+
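# Sketch of the new raster arguments, assuming `ssl`, `spikes` and `channels` from the sketches
# above; the drift values are dummy numbers, a real drift Bunch would otherwise be downloaded by
# the method itself.
drift = {'times': np.array([0.0, 100.0, 200.0]), 'um': np.array([0.0, 4.0, -3.0])}  # dummy drift
fig, axs = ssl.raster(spikes, channels, title='my custom title', drift=drift)  # returns the figure
ssl.raster(spikes, channels, save_dir=Path('/tmp'), label='raster')            # saves a png instead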
+    def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
+                             channels=None,
+                             br: BrainRegions = None,
+                             save_dir=None,
+                             label='raster',
+                             gain=-93,
+                             title=None):
+
+        # compute the raw data offset and destripe, we take 400ms around t0
+        first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
+        raw = sr[first_sample:last_sample, :-sr.nsync].T
+        channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
+        destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
+        # filter out the spikes according to good/bad clusters and to the time slice
+        spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
+        ss = spikes['samples'][spike_sel]
+        sc = clusters['channels'][spikes['clusters'][spike_sel]]
+        sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
+        if title is None:
+            title = self._default_plot_title(spikes)
+        # display the raw data snippet with spikes overlaid
+        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
+        Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
+        axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
+        axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
+        axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
+        # adds the channel locations if available
+        if (channels is not None) and ('atlas_id' in channels):
+            br = br or BrainRegions()
+            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
+                               brain_regions=br, display=True, ax=axs[1], title=self.histology)
+            axs[1].get_yaxis().set_visible(False)
+        fig.tight_layout()

        if save_dir is not None:
            png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
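# A possible end-to-end driver for the snippet plot, reusing `ssl` from the sketches above;
# spikes['samples'] and clusters['label'] are assumed to be available (e.g. by requesting
# dataset_types=['spikes.samples'] and a quality-labelled clusters table), and returning
# (fig, axs) when save_dir is None mirrors the raster() convention.
sr = ssl.raw_electrophysiology(band='ap', stream=True)  # Streamer on the AP band
spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
fig, axs = ssl.plot_rawdata_snippet(sr, spikes, clusters, t0=600.0, channels=channels)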