1616import re
1717import csv
1818import ast
19+ import warnings
1920
2021
2122class PhyRawIO (BaseRawIO ):
@@ -35,9 +36,11 @@ class PhyRawIO(BaseRawIO):
3536 extensions = []
3637 rawmode = 'one-dir'
3738
38- def __init__ (self , dirname = '' ):
39+ def __init__ (self , dirname = '' , load_amplitudes = False , load_pcs = False ):
3940 BaseRawIO .__init__ (self )
4041 self .dirname = dirname
42+ self .load_pcs = load_pcs
43+ self .load_amplitudes = load_amplitudes
4144
4245 def _source_name (self ):
4346 return self .dirname
@@ -53,16 +56,24 @@ def _parse_header(self):
5356 else :
5457 self ._spike_clusters = self ._spike_templates
5558
56- # TODO: Add this when array_annotations are ready
57- # if (phy_folder / 'amplitudes.npy').is_file():
58- # amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
59- # else:
60- # amplitudes = np.ones(len(spike_times))
61- #
62- # if (phy_folder / 'pc_features.npy').is_file():
63- # pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
64- # else:
65- # pc_features = None
59+ self ._amplitudes = None
60+ if self .load_amplitudes :
61+ if (phy_folder / 'amplitudes.npy' ).is_file ():
62+ self ._amplitudes = np .squeeze (np .load (phy_folder / 'amplitudes.npy' ))
63+ else :
64+ warnings .warn ('Amplitudes requested but "amplitudes.npy"'
65+ 'not found in the data folder.' )
66+
67+ self ._pc_features = None
68+ self ._pc_feature_ind = None
69+ if self .load_pcs :
70+ if ((phy_folder / 'pc_features.npy' ).is_file ()
71+ and (phy_folder / 'pc_feature_ind.npy' ).is_file ()):
72+ self ._pc_features = np .squeeze (np .load (phy_folder / 'pc_features.npy' ))
73+ self ._pc_feature_ind = np .squeeze (np .load (phy_folder / 'pc_feature_ind.npy' ))
74+ else :
75+ warnings .warn ('PCs requested but "pc_features.npy" and/or'
76+ '"pc_feature_ind.npy" not found in the data folder.' )
6677
6778 # SEE: https://stackoverflow.com/questions/4388626/
6879 # python-safe-eval-string-to-bool-int-float-none-string
@@ -150,6 +161,30 @@ def _parse_header(self):
150161 annotation_dict [property_name ]
151162 break
152163
164+ cluster_mask = (self ._spike_clusters == clust_id ).flatten ()
165+
166+ current_templates = self ._spike_templates [cluster_mask ].flatten ()
167+ unique_templates = np .unique (current_templates )
168+ spiketrain_an ['templates' ] = unique_templates
169+ spiketrain_an ['__array_annotations__' ]['templates' ] = current_templates
170+
171+ if self ._amplitudes is not None :
172+ spiketrain_an ['__array_annotations__' ]['amplitudes' ] = \
173+ self ._amplitudes [cluster_mask ]
174+
175+ if self ._pc_features is not None :
176+ current_pc_features = self ._pc_features [cluster_mask ]
177+ _ , num_pcs , num_pc_channels = current_pc_features .shape
178+ for pc_idx in range (num_pcs ):
179+ for channel_idx in range (num_pc_channels ):
180+ key = 'channel{channel_idx}_pc{pc_idx}' .format (channel_idx = channel_idx ,
181+ pc_idx = pc_idx )
182+ spiketrain_an ['__array_annotations__' ][key ] = \
183+ current_pc_features [:, pc_idx , channel_idx ]
184+
185+ if self ._pc_feature_ind is not None :
186+ spiketrain_an ['pc_feature_ind' ] = self ._pc_feature_ind [unique_templates ]
187+
153188 def _segment_t_start (self , block_index , seg_index ):
154189 assert block_index == 0
155190 return self ._t_start
0 commit comments