Skip to content

Commit 90606a6

Browse files
Merge pull request #1153 from INM-6/enh/phyrawio_read_multicolumn_csvs
Read multi-column csvs in PhyIO
2 parents ff190ba + 075110b commit 90606a6

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

neo/rawio/phyrawio.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,17 @@ def _parse_header(self):
149149

150150
# Loop over list of list of dict and annotate each st
151151
for annotation_list in annotation_lists:
152-
clust_key, property_name = tuple(annotation_list[0].
153-
keys())
154-
if property_name == 'KSLabel':
155-
annotation_name = 'quality'
156-
else:
157-
annotation_name = property_name.lower()
158-
for annotation_dict in annotation_list:
159-
if int(annotation_dict[clust_key]) == clust_id:
160-
spiketrain_an[annotation_name] = \
161-
annotation_dict[property_name]
162-
break
152+
clust_key, *property_names = tuple(annotation_list[0].keys())
153+
for property_name in property_names:
154+
if property_name == 'KSLabel':
155+
annotation_name = 'quality'
156+
else:
157+
annotation_name = property_name.lower()
158+
for annotation_dict in annotation_list:
159+
if int(annotation_dict[clust_key]) == clust_id:
160+
spiketrain_an[annotation_name] = \
161+
annotation_dict[property_name]
162+
break
163163

164164
cluster_mask = (self._spike_clusters == clust_id).flatten()
165165

@@ -256,7 +256,7 @@ def _rescale_epoch_duration(self, raw_duration, dtype):
256256
def _parse_tsv_or_csv_to_list_of_dict(filename):
257257
list_of_dict = list()
258258
letter_pattern = re.compile('[a-zA-Z]')
259-
float_pattern = re.compile(r'\d*\.')
259+
float_pattern = re.compile(r'-?\d*\.')
260260
with open(filename) as csvfile:
261261
if filename.suffix == '.csv':
262262
reader = csv.DictReader(csvfile, delimiter=',')
@@ -268,15 +268,19 @@ def _parse_tsv_or_csv_to_list_of_dict(filename):
268268

269269
for row in reader:
270270
if line == 0:
271-
key1, key2 = tuple(row.keys())
271+
cluster_id_key, *annotation_keys = tuple(row.keys())
272272
# Convert cluster ID to int
273-
row[key1] = int(row[key1])
273+
row[cluster_id_key] = int(row[cluster_id_key])
274274
# Convert strings without letters
275-
if letter_pattern.match(row[key2]) is None:
276-
if float_pattern.match(row[key2]) is None:
277-
row[key2] = int(row[key2])
278-
else:
279-
row[key2] = float(row[key2])
275+
for key in annotation_keys:
276+
value = row[key]
277+
if not len(value):
278+
row[key] = None
279+
elif letter_pattern.match(value) is None:
280+
if float_pattern.match(value) is None:
281+
row[key] = int(value)
282+
else:
283+
row[key] = float(value)
280284

281285
list_of_dict.append(row)
282286
line += 1

neo/test/rawiotest/test_phyrawio.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,35 @@ def test_csv_tsv_parser_with_csv(self):
3131
csv_tempfile = Path(tempfile.gettempdir()).joinpath('test.csv')
3232
with open(csv_tempfile, 'w') as csv_file:
3333
csv_writer = csv.writer(csv_file, delimiter=',')
34-
csv_writer.writerow(['cluster_id', 'some_annotation'])
35-
csv_writer.writerow([1, 'Good'])
36-
csv_writer.writerow([2, 10])
37-
csv_writer.writerow([3, 1.23])
34+
csv_writer.writerow(['cluster_id', 'some_annotation', 'some_other_annotation'])
35+
csv_writer.writerow([1, 'Good', 'Bad'])
36+
csv_writer.writerow([2, 10, -2])
37+
csv_writer.writerow([3, 1.23, -0.38])
3838

3939
# the parser in PhyRawIO runs csv.DictReader to parse the file
4040
# csv.DictReader for python version 3.6+ returns list of OrderedDict
4141
if (3, 6) <= sys.version_info < (3, 8):
4242
target = [OrderedDict({'cluster_id': 1,
43-
'some_annotation': 'Good'}),
43+
'some_annotation': 'Good',
44+
'some_other_annotation': 'Bad'}),
4445
OrderedDict({'cluster_id': 2,
45-
'some_annotation': 10}),
46+
'some_annotation': 10,
47+
'some_other_annotation': -2}),
4648
OrderedDict({'cluster_id': 3,
47-
'some_annotation': 1.23})]
49+
'some_annotation': 1.23,
50+
'some_other_annotation': -0.38})]
4851

4952
# csv.DictReader for python version 3.8+ returns list of dict
5053
elif sys.version_info >= (3, 8):
51-
target = [{'cluster_id': 1, 'some_annotation': 'Good'},
52-
{'cluster_id': 2, 'some_annotation': 10},
53-
{'cluster_id': 3, 'some_annotation': 1.23}]
54+
target = [{'cluster_id': 1,
55+
'some_annotation': 'Good',
56+
'some_other_annotation': 'Bad'},
57+
{'cluster_id': 2,
58+
'some_annotation': 10,
59+
'some_other_annotation': -2},
60+
{'cluster_id': 3,
61+
'some_annotation': 1.23,
62+
'some_other_annotation': -0.38}]
5463

5564
list_of_dict = PhyRawIO._parse_tsv_or_csv_to_list_of_dict(csv_tempfile)
5665

0 commit comments

Comments
 (0)