Skip to content

Commit 26446d2

Browse files
Merge pull request #1158 from INM-6/enh/phyio_read_amplitudes_and_pcs
Array-annotate Amplitudes and PCs in PhyIO
2 parents 1af517c + 378b08e commit 26446d2

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

neo/io/phyio.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class PhyIO(PhyRawIO, BaseFromRaw):
77
description = "Phy IO"
88
mode = 'dir'
99

10-
def __init__(self, dirname):
11-
PhyRawIO.__init__(self, dirname=dirname)
10+
def __init__(self, dirname, load_amplitudes=False, load_pcs=False):
11+
PhyRawIO.__init__(self,
12+
dirname=dirname,
13+
load_amplitudes=load_amplitudes,
14+
load_pcs=load_pcs)
1215
BaseFromRaw.__init__(self, dirname)

neo/rawio/phyrawio.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
import csv
1818
import ast
19+
import warnings
1920

2021

2122
class PhyRawIO(BaseRawIO):
@@ -35,9 +36,11 @@ class PhyRawIO(BaseRawIO):
3536
extensions = []
3637
rawmode = 'one-dir'
3738

38-
def __init__(self, dirname=''):
39+
def __init__(self, dirname='', load_amplitudes=False, load_pcs=False):
3940
BaseRawIO.__init__(self)
4041
self.dirname = dirname
42+
self.load_pcs = load_pcs
43+
self.load_amplitudes = load_amplitudes
4144

4245
def _source_name(self):
4346
return self.dirname
@@ -53,16 +56,24 @@ def _parse_header(self):
5356
else:
5457
self._spike_clusters = self._spike_templates
5558

56-
# TODO: Add this when array_annotations are ready
57-
# if (phy_folder / 'amplitudes.npy').is_file():
58-
# amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
59-
# else:
60-
# amplitudes = np.ones(len(spike_times))
61-
#
62-
# if (phy_folder / 'pc_features.npy').is_file():
63-
# pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
64-
# else:
65-
# pc_features = None
59+
self._amplitudes = None
60+
if self.load_amplitudes:
61+
if (phy_folder / 'amplitudes.npy').is_file():
62+
self._amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
63+
else:
64+
warnings.warn('Amplitudes requested but "amplitudes.npy"'
65+
'not found in the data folder.')
66+
67+
self._pc_features = None
68+
self._pc_feature_ind = None
69+
if self.load_pcs:
70+
if ((phy_folder / 'pc_features.npy').is_file()
71+
and (phy_folder / 'pc_feature_ind.npy').is_file()):
72+
self._pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
73+
self._pc_feature_ind = np.squeeze(np.load(phy_folder / 'pc_feature_ind.npy'))
74+
else:
75+
warnings.warn('PCs requested but "pc_features.npy" and/or'
76+
'"pc_feature_ind.npy" not found in the data folder.')
6677

6778
# SEE: https://stackoverflow.com/questions/4388626/
6879
# python-safe-eval-string-to-bool-int-float-none-string
@@ -150,6 +161,30 @@ def _parse_header(self):
150161
annotation_dict[property_name]
151162
break
152163

164+
cluster_mask = (self._spike_clusters == clust_id).flatten()
165+
166+
current_templates = self._spike_templates[cluster_mask].flatten()
167+
unique_templates = np.unique(current_templates)
168+
spiketrain_an['templates'] = unique_templates
169+
spiketrain_an['__array_annotations__']['templates'] = current_templates
170+
171+
if self._amplitudes is not None:
172+
spiketrain_an['__array_annotations__']['amplitudes'] = \
173+
self._amplitudes[cluster_mask]
174+
175+
if self._pc_features is not None:
176+
current_pc_features = self._pc_features[cluster_mask]
177+
_, num_pcs, num_pc_channels = current_pc_features.shape
178+
for pc_idx in range(num_pcs):
179+
for channel_idx in range(num_pc_channels):
180+
key = 'channel{channel_idx}_pc{pc_idx}'.format(channel_idx=channel_idx,
181+
pc_idx=pc_idx)
182+
spiketrain_an['__array_annotations__'][key] = \
183+
current_pc_features[:, pc_idx, channel_idx]
184+
185+
if self._pc_feature_ind is not None:
186+
spiketrain_an['pc_feature_ind'] = self._pc_feature_ind[unique_templates]
187+
153188
def _segment_t_start(self, block_index, seg_index):
154189
assert block_index == 0
155190
return self._t_start

0 commit comments

Comments
 (0)