Skip to content

Commit 7254820

Browse files
authored
Merge pull request #997 from rjurkus/fix/phyio_annotations
[phyio] Fixed spiketrain annotations
2 parents a6c1435 + 4984b33 commit 7254820

File tree

2 files changed

+54
-17
lines changed

2 files changed

+54
-17
lines changed

neo/rawio/phyrawio.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
12-
_spike_channel_dtype, _event_channel_dtype)
12+
_spike_channel_dtype, _event_channel_dtype)
1313

1414
import numpy as np
1515
from pathlib import Path
@@ -88,7 +88,8 @@ def _parse_header(self):
8888
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
8989

9090
signal_channels = []
91-
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
91+
signal_channels = np.array(signal_channels,
92+
dtype=_signal_channel_dtype)
9293

9394
spike_channels = []
9495
for i, clust_id in enumerate(clust_ids):
@@ -132,6 +133,9 @@ def _parse_header(self):
132133
for index, clust_id in enumerate(clust_ids):
133134
spiketrain_an = seg_ann['spikes'][index]
134135

136+
# Add cluster_id annotation
137+
spiketrain_an['cluster_id'] = clust_id
138+
135139
# Loop over list of list of dict and annotate each st
136140
for annotation_list in annotation_lists:
137141
clust_key, property_name = tuple(annotation_list[0].
@@ -172,8 +176,8 @@ def _spike_count(self, block_index, seg_index, spike_channel_index):
172176
nb_spikes = np.sum(mask)
173177
return nb_spikes
174178

175-
def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index,
176-
t_start, t_stop):
179+
def _get_spike_timestamps(self, block_index, seg_index,
180+
spike_channel_index, t_start, t_stop):
177181
assert block_index == 0
178182
assert seg_index == 0
179183

@@ -183,7 +187,8 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index,
183187

184188
if t_start is not None:
185189
start_frame = int(t_start * self._sampling_frequency)
186-
spike_timestamps = spike_timestamps[spike_timestamps >= start_frame]
190+
spike_timestamps = \
191+
spike_timestamps[spike_timestamps >= start_frame]
187192
if t_stop is not None:
188193
end_frame = int(t_stop * self._sampling_frequency)
189194
spike_timestamps = spike_timestamps[spike_timestamps < end_frame]
@@ -195,8 +200,8 @@ def _rescale_spike_timestamp(self, spike_timestamps, dtype):
195200
spike_times /= self._sampling_frequency
196201
return spike_times
197202

198-
def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
199-
t_start, t_stop):
203+
def _get_spike_raw_waveforms(self, block_index, seg_index,
204+
spike_channel_index, t_start, t_stop):
200205
return None
201206

202207
def _event_count(self, block_index, seg_index, event_channel_index):
@@ -215,14 +220,30 @@ def _rescale_epoch_duration(self, raw_duration, dtype):
215220
@staticmethod
216221
def _parse_tsv_or_csv_to_list_of_dict(filename):
217222
list_of_dict = list()
223+
letter_pattern = re.compile('[a-zA-Z]')
224+
float_pattern = re.compile(r'\d*\.')
218225
with open(filename) as csvfile:
219226
if filename.suffix == '.csv':
220227
reader = csv.DictReader(csvfile, delimiter=',')
221228
elif filename.suffix == '.tsv':
222229
reader = csv.DictReader(csvfile, delimiter='\t')
223230
else:
224231
raise ValueError("Function parses only .csv or .tsv files")
232+
line = 0
233+
225234
for row in reader:
235+
if line == 0:
236+
key1, key2 = tuple(row.keys())
237+
# Convert cluster ID to int
238+
row[key1] = int(row[key1])
239+
# Convert strings without letters
240+
if letter_pattern.match(row[key2]) is None:
241+
if float_pattern.match(row[key2]) is None:
242+
row[key2] = int(row[key2])
243+
else:
244+
row[key2] = float(row[key2])
245+
226246
list_of_dict.append(row)
247+
line += 1
227248

228249
return list_of_dict

neo/test/rawiotest/test_phyrawio.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,26 @@ def test_csv_tsv_parser_with_csv(self):
3131
csv_tempfile = Path(tempfile.gettempdir()).joinpath('test.csv')
3232
with open(csv_tempfile, 'w') as csv_file:
3333
csv_writer = csv.writer(csv_file, delimiter=',')
34-
csv_writer.writerow(['Header 1', 'Header 2'])
35-
csv_writer.writerow(['Value 1', 'Value 2'])
34+
csv_writer.writerow(['cluster_id', 'some_annotation'])
35+
csv_writer.writerow([1, 'Good'])
36+
csv_writer.writerow([2, 10])
37+
csv_writer.writerow([3, 1.23])
3638

3739
# the parser in PhyRawIO runs csv.DictReader to parse the file
3840
# csv.DictReader for python version 3.6+ returns list of OrderedDict
3941
if (3, 6) <= sys.version_info < (3, 8):
40-
target = [OrderedDict({'Header 1': 'Value 1',
41-
'Header 2': 'Value 2'})]
42+
target = [OrderedDict({'cluster_id': 1,
43+
'some_annotation': 'Good'}),
44+
OrderedDict({'cluster_id': 2,
45+
'some_annotation': 10}),
46+
OrderedDict({'cluster_id': 3,
47+
'some_annotation': 1.23})]
4248

4349
# csv.DictReader for python version 3.8+ returns list of dict
4450
elif sys.version_info >= (3, 8):
45-
target = [{'Header 1': 'Value 1', 'Header 2': 'Value 2'}]
51+
target = [{'cluster_id': 1, 'some_annotation': 'Good'},
52+
{'cluster_id': 2, 'some_annotation': 10},
53+
{'cluster_id': 3, 'some_annotation': 1.23}]
4654

4755
list_of_dict = PhyRawIO._parse_tsv_or_csv_to_list_of_dict(csv_tempfile)
4856

@@ -52,18 +60,26 @@ def test_csv_tsv_parser_with_tsv(self):
5260
tsv_tempfile = Path(tempfile.gettempdir()).joinpath('test.tsv')
5361
with open(tsv_tempfile, 'w') as tsv_file:
5462
tsv_writer = csv.writer(tsv_file, delimiter='\t')
55-
tsv_writer.writerow(['Header 1', 'Header 2'])
56-
tsv_writer.writerow(['Value 1', 'Value 2'])
63+
tsv_writer.writerow(['cluster_id', 'some_annotation'])
64+
tsv_writer.writerow([1, 'Good'])
65+
tsv_writer.writerow([2, 10])
66+
tsv_writer.writerow([3, 1.23])
5767

5868
# the parser in PhyRawIO runs csv.DictReader to parse the file
5969
# csv.DictReader for python version 3.6+ returns list of OrderedDict
6070
if (3, 6) <= sys.version_info < (3, 8):
61-
target = [OrderedDict({'Header 1': 'Value 1',
62-
'Header 2': 'Value 2'})]
71+
target = [OrderedDict({'cluster_id': 1,
72+
'some_annotation': 'Good'}),
73+
OrderedDict({'cluster_id': 2,
74+
'some_annotation': 10}),
75+
OrderedDict({'cluster_id': 3,
76+
'some_annotation': 1.23})]
6377

6478
# csv.DictReader for python version 3.8+ returns list of dict
6579
elif sys.version_info >= (3, 8):
66-
target = [{'Header 1': 'Value 1', 'Header 2': 'Value 2'}]
80+
target = [{'cluster_id': 1, 'some_annotation': 'Good'},
81+
{'cluster_id': 2, 'some_annotation': 10},
82+
{'cluster_id': 3, 'some_annotation': 1.23}]
6783

6884
list_of_dict = PhyRawIO._parse_tsv_or_csv_to_list_of_dict(tsv_tempfile)
6985

0 commit comments

Comments
 (0)