1616import re
1717import csv
1818import ast
19+ import warnings
1920
2021
2122class PhyRawIO (BaseRawIO ):
@@ -35,9 +36,10 @@ class PhyRawIO(BaseRawIO):
3536 extensions = []
3637 rawmode = 'one-dir'
3738
38- def __init__ (self , dirname = '' ):
39+ def __init__ (self , dirname = '' , load_pcs = False ):
3940 BaseRawIO .__init__ (self )
4041 self .dirname = dirname
42+ self .load_pcs = load_pcs
4143
4244 def _source_name (self ):
4345 return self .dirname
@@ -53,16 +55,21 @@ def _parse_header(self):
5355 else :
5456 self ._spike_clusters = self ._spike_templates
5557
56- # TODO: Add this when array_annotations are ready
57- # if (phy_folder / 'amplitudes.npy').is_file():
58- # amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
59- # else:
60- # amplitudes = np.ones(len(spike_times))
61- #
62- # if (phy_folder / 'pc_features.npy').is_file():
63- # pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
64- # else:
65- # pc_features = None
58+ if (phy_folder / 'amplitudes.npy' ).is_file ():
59+ self ._amplitudes = np .squeeze (np .load (phy_folder / 'amplitudes.npy' ))
60+ else :
61+ self ._amplitudes = None
62+
63+ self ._pc_features = None
64+ self ._pc_feature_ind = None
65+ if self .load_pcs :
66+ if ((phy_folder / 'pc_features.npy' ).is_file ()
67+ and (phy_folder / 'pc_feature_ind.npy' ).is_file ()):
68+ self ._pc_features = np .squeeze (np .load (phy_folder / 'pc_features.npy' ))
69+ self ._pc_feature_ind = np .squeeze (np .load (phy_folder / 'pc_feature_ind.npy' ))
70+ else :
71+ warnings .warn ('PCs requested but "pc_features.npy" and/or'
72+ '"pc_feature_ind.npy" not found in the data folder.' )
6673
6774 # SEE: https://stackoverflow.com/questions/4388626/
6875 # python-safe-eval-string-to-bool-int-float-none-string
@@ -150,6 +157,25 @@ def _parse_header(self):
150157 annotation_dict [property_name ]
151158 break
152159
160+ cluster_mask = (self ._spike_clusters == clust_id ).flatten ()
161+
162+ if self ._amplitudes is not None :
163+ spiketrain_an ['__array_annotations__' ]['amplitudes' ] = \
164+ self ._amplitudes [cluster_mask ]
165+
166+ if self ._pc_features is not None :
167+ current_pc_features = self ._pc_features [cluster_mask ]
168+ _ , num_pcs , num_pc_channels = current_pc_features .shape
169+ for pc_idx in range (num_pcs ):
170+ for channel_idx in range (num_pc_channels ):
171+ key = 'channel{channel_idx}_pc{pc_idx}' .format (channel_idx = channel_idx ,
172+ pc_idx = pc_idx )
173+ spiketrain_an ['__array_annotations__' ][key ] = \
174+ current_pc_features [:, pc_idx , channel_idx ]
175+
176+ if self ._pc_feature_ind is not None :
177+ spiketrain_an ['pc_feature_ind' ] = self ._pc_feature_ind [index ]
178+
153179 def _segment_t_start (self , block_index , seg_index ):
154180 assert block_index == 0
155181 return self ._t_start
0 commit comments