Skip to content

Commit 7cddeb8

Browse files
authored
Merge branch 'NeuralEnsemble:master' into cboulay/br_filespec_3_0
2 parents 375e474 + 90606a6 commit 7cddeb8

File tree

8 files changed

+167
-47
lines changed

8 files changed

+167
-47
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
---
2+
name: Bug report
3+
about: Create a report to help us fix problems
4+
title: ''
5+
labels: bug
6+
assignees: ''
7+
8+
---
9+
10+
**Describe the bug**
11+
A clear and concise description of what the bug is.
12+
13+
**To Reproduce**
14+
Steps to reproduce the behaviour, preferably providing a simple code example (if the error happens in the middle of some complex code, please try to find a simpler, minimal example that demonstrates the error), and showing the full traceback.
15+
16+
If the error occurs when reading a file that you can't share publicly, please let us know, and we'll get in touch to discuss sharing it privately.
17+
18+
**Expected behaviour**
19+
If the bug is incorrect behaviour, rather than an unexpected Exception, please give a clear and concise description of what you expected to happen.
20+
21+
**Environment:**
22+
- OS: [e.g. macOS, Linux, Windows]
23+
- Python version
24+
- Neo version
25+
- NumPy version
26+
27+
**Additional context**
28+
Add any other context about the problem here.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
---
2+
name: Confusing documentation
3+
about: Let us know if the documentation is confusing or incorrect
4+
title: ''
5+
labels: Documentation
6+
assignees: ''
7+
8+
---
9+
10+
**Which page is the problem on?**
11+
The URL of the documentation page where the problem is, and either copy-paste the confusing text (for a short section of text), or give the first few and last few words (for a long section).
12+
13+
**What is the problem?**
14+
Is the documentation (a) confusing or (b) incorrect? In what way?
15+
16+
**Suggestions for fixing the problem**
17+
If the documentation is confusing, can you suggest an improvement? If the documentation is incorrect, what should it say instead?
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
name: Feature request
3+
about: Suggest an idea to improve Neo
4+
title: ''
5+
labels: enhancement
6+
assignees: ''
7+
8+
---
9+
10+
**Is your feature request related to a problem? Please describe.**
11+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12+
13+
**Describe the solution you'd like**
14+
A clear and concise description of what you want to happen.
15+
16+
**Describe alternatives you've considered**
17+
A clear and concise description of any alternative solutions or features you've considered.
18+
19+
**Additional context**
20+
Add any other context or screenshots about the feature request here.

neo/core/epoch.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
'''
77

88
from copy import deepcopy, copy
9+
from numbers import Number
910

1011
import numpy as np
1112
import quantities as pq
@@ -53,10 +54,12 @@ class Epoch(DataObject):
5354
*Required attributes/properties*:
5455
:times: (quantity array 1D, numpy array 1D or list) The start times
5556
of each time period.
56-
:durations: (quantity array 1D, numpy array 1D, list, or quantity scalar)
57+
:durations: (quantity array 1D, numpy array 1D, list, quantity scalar or float)
5758
The length(s) of each time period.
58-
If a scalar, the same value is used for all time periods.
59+
If a scalar/float, the same value is used for all time periods.
5960
:labels: (numpy.array 1D dtype='U' or list) Names or labels for the time periods.
61+
:units: (quantity units or str) Required if the times is a list or NumPy
62+
array, not if it is a :class:`Quantity`
6063
6164
*Recommended attributes/properties*:
6265
:name: (str) A label for the dataset,
@@ -88,8 +91,10 @@ def __new__(cls, times=None, durations=None, labels=None, units=None, name=None,
8891
raise ValueError("Times array has more than 1 dimension")
8992
if isinstance(durations, (list, tuple)):
9093
durations = np.array(durations)
91-
if durations is None:
94+
elif durations is None:
9295
durations = np.array([]) * pq.s
96+
elif isinstance(durations, Number):
97+
durations = durations * np.ones(times.shape)
9398
elif durations.size != times.size:
9499
if durations.size == 1:
95100
durations = durations * np.ones_like(times.magnitude)

neo/io/phyio.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ class PhyIO(PhyRawIO, BaseFromRaw):
77
description = "Phy IO"
88
mode = 'dir'
99

10-
def __init__(self, dirname):
11-
PhyRawIO.__init__(self, dirname=dirname)
10+
def __init__(self, dirname, load_amplitudes=False, load_pcs=False):
11+
PhyRawIO.__init__(self,
12+
dirname=dirname,
13+
load_amplitudes=load_amplitudes,
14+
load_pcs=load_pcs)
1215
BaseFromRaw.__init__(self, dirname)

neo/rawio/blackrockrawio.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,7 @@ def _parse_header(self):
458458
# Find maximal and minimal time for each nev segment
459459
for k, (data, ev_ids) in self.nev_data.items():
460460
for i in np.unique(ev_ids):
461-
mask = [ev_ids == i]
462-
curr_data = data[mask]
461+
curr_data = data[ev_ids == i]
463462
if curr_data.size > 0:
464463
if max(curr_data['timestamp']) >= max_nev_times.get(i, 0):
465464
max_nev_times[i] = max(curr_data['timestamp'])

neo/rawio/phyrawio.py

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
import csv
1818
import ast
19+
import warnings
1920

2021

2122
class PhyRawIO(BaseRawIO):
@@ -35,9 +36,11 @@ class PhyRawIO(BaseRawIO):
3536
extensions = []
3637
rawmode = 'one-dir'
3738

38-
def __init__(self, dirname=''):
39+
def __init__(self, dirname='', load_amplitudes=False, load_pcs=False):
3940
BaseRawIO.__init__(self)
4041
self.dirname = dirname
42+
self.load_pcs = load_pcs
43+
self.load_amplitudes = load_amplitudes
4144

4245
def _source_name(self):
4346
return self.dirname
@@ -53,16 +56,24 @@ def _parse_header(self):
5356
else:
5457
self._spike_clusters = self._spike_templates
5558

56-
# TODO: Add this when array_annotations are ready
57-
# if (phy_folder / 'amplitudes.npy').is_file():
58-
# amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
59-
# else:
60-
# amplitudes = np.ones(len(spike_times))
61-
#
62-
# if (phy_folder / 'pc_features.npy').is_file():
63-
# pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
64-
# else:
65-
# pc_features = None
59+
self._amplitudes = None
60+
if self.load_amplitudes:
61+
if (phy_folder / 'amplitudes.npy').is_file():
62+
self._amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
63+
else:
64+
warnings.warn('Amplitudes requested but "amplitudes.npy"'
65+
'not found in the data folder.')
66+
67+
self._pc_features = None
68+
self._pc_feature_ind = None
69+
if self.load_pcs:
70+
if ((phy_folder / 'pc_features.npy').is_file()
71+
and (phy_folder / 'pc_feature_ind.npy').is_file()):
72+
self._pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
73+
self._pc_feature_ind = np.squeeze(np.load(phy_folder / 'pc_feature_ind.npy'))
74+
else:
75+
warnings.warn('PCs requested but "pc_features.npy" and/or'
76+
'"pc_feature_ind.npy" not found in the data folder.')
6677

6778
# SEE: https://stackoverflow.com/questions/4388626/
6879
# python-safe-eval-string-to-bool-int-float-none-string
@@ -138,17 +149,41 @@ def _parse_header(self):
138149

139150
# Loop over list of list of dict and annotate each st
140151
for annotation_list in annotation_lists:
141-
clust_key, property_name = tuple(annotation_list[0].
142-
keys())
143-
if property_name == 'KSLabel':
144-
annotation_name = 'quality'
145-
else:
146-
annotation_name = property_name.lower()
147-
for annotation_dict in annotation_list:
148-
if int(annotation_dict[clust_key]) == clust_id:
149-
spiketrain_an[annotation_name] = \
150-
annotation_dict[property_name]
151-
break
152+
clust_key, *property_names = tuple(annotation_list[0].keys())
153+
for property_name in property_names:
154+
if property_name == 'KSLabel':
155+
annotation_name = 'quality'
156+
else:
157+
annotation_name = property_name.lower()
158+
for annotation_dict in annotation_list:
159+
if int(annotation_dict[clust_key]) == clust_id:
160+
spiketrain_an[annotation_name] = \
161+
annotation_dict[property_name]
162+
break
163+
164+
cluster_mask = (self._spike_clusters == clust_id).flatten()
165+
166+
current_templates = self._spike_templates[cluster_mask].flatten()
167+
unique_templates = np.unique(current_templates)
168+
spiketrain_an['templates'] = unique_templates
169+
spiketrain_an['__array_annotations__']['templates'] = current_templates
170+
171+
if self._amplitudes is not None:
172+
spiketrain_an['__array_annotations__']['amplitudes'] = \
173+
self._amplitudes[cluster_mask]
174+
175+
if self._pc_features is not None:
176+
current_pc_features = self._pc_features[cluster_mask]
177+
_, num_pcs, num_pc_channels = current_pc_features.shape
178+
for pc_idx in range(num_pcs):
179+
for channel_idx in range(num_pc_channels):
180+
key = 'channel{channel_idx}_pc{pc_idx}'.format(channel_idx=channel_idx,
181+
pc_idx=pc_idx)
182+
spiketrain_an['__array_annotations__'][key] = \
183+
current_pc_features[:, pc_idx, channel_idx]
184+
185+
if self._pc_feature_ind is not None:
186+
spiketrain_an['pc_feature_ind'] = self._pc_feature_ind[unique_templates]
152187

153188
def _segment_t_start(self, block_index, seg_index):
154189
assert block_index == 0
@@ -221,7 +256,7 @@ def _rescale_epoch_duration(self, raw_duration, dtype):
221256
def _parse_tsv_or_csv_to_list_of_dict(filename):
222257
list_of_dict = list()
223258
letter_pattern = re.compile('[a-zA-Z]')
224-
float_pattern = re.compile(r'\d*\.')
259+
float_pattern = re.compile(r'-?\d*\.')
225260
with open(filename) as csvfile:
226261
if filename.suffix == '.csv':
227262
reader = csv.DictReader(csvfile, delimiter=',')
@@ -233,15 +268,19 @@ def _parse_tsv_or_csv_to_list_of_dict(filename):
233268

234269
for row in reader:
235270
if line == 0:
236-
key1, key2 = tuple(row.keys())
271+
cluster_id_key, *annotation_keys = tuple(row.keys())
237272
# Convert cluster ID to int
238-
row[key1] = int(row[key1])
273+
row[cluster_id_key] = int(row[cluster_id_key])
239274
# Convert strings without letters
240-
if letter_pattern.match(row[key2]) is None:
241-
if float_pattern.match(row[key2]) is None:
242-
row[key2] = int(row[key2])
243-
else:
244-
row[key2] = float(row[key2])
275+
for key in annotation_keys:
276+
value = row[key]
277+
if not len(value):
278+
row[key] = None
279+
elif letter_pattern.match(value) is None:
280+
if float_pattern.match(value) is None:
281+
row[key] = int(value)
282+
else:
283+
row[key] = float(value)
245284

246285
list_of_dict.append(row)
247286
line += 1

neo/test/rawiotest/test_phyrawio.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,35 @@ def test_csv_tsv_parser_with_csv(self):
3131
csv_tempfile = Path(tempfile.gettempdir()).joinpath('test.csv')
3232
with open(csv_tempfile, 'w') as csv_file:
3333
csv_writer = csv.writer(csv_file, delimiter=',')
34-
csv_writer.writerow(['cluster_id', 'some_annotation'])
35-
csv_writer.writerow([1, 'Good'])
36-
csv_writer.writerow([2, 10])
37-
csv_writer.writerow([3, 1.23])
34+
csv_writer.writerow(['cluster_id', 'some_annotation', 'some_other_annotation'])
35+
csv_writer.writerow([1, 'Good', 'Bad'])
36+
csv_writer.writerow([2, 10, -2])
37+
csv_writer.writerow([3, 1.23, -0.38])
3838

3939
# the parser in PhyRawIO runs csv.DictReader to parse the file
4040
# csv.DictReader for python version 3.6+ returns list of OrderedDict
4141
if (3, 6) <= sys.version_info < (3, 8):
4242
target = [OrderedDict({'cluster_id': 1,
43-
'some_annotation': 'Good'}),
43+
'some_annotation': 'Good',
44+
'some_other_annotation': 'Bad'}),
4445
OrderedDict({'cluster_id': 2,
45-
'some_annotation': 10}),
46+
'some_annotation': 10,
47+
'some_other_annotation': -2}),
4648
OrderedDict({'cluster_id': 3,
47-
'some_annotation': 1.23})]
49+
'some_annotation': 1.23,
50+
'some_other_annotation': -0.38})]
4851

4952
# csv.DictReader for python version 3.8+ returns list of dict
5053
elif sys.version_info >= (3, 8):
51-
target = [{'cluster_id': 1, 'some_annotation': 'Good'},
52-
{'cluster_id': 2, 'some_annotation': 10},
53-
{'cluster_id': 3, 'some_annotation': 1.23}]
54+
target = [{'cluster_id': 1,
55+
'some_annotation': 'Good',
56+
'some_other_annotation': 'Bad'},
57+
{'cluster_id': 2,
58+
'some_annotation': 10,
59+
'some_other_annotation': -2},
60+
{'cluster_id': 3,
61+
'some_annotation': 1.23,
62+
'some_other_annotation': -0.38}]
5463

5564
list_of_dict = PhyRawIO._parse_tsv_or_csv_to_list_of_dict(csv_tempfile)
5665

0 commit comments

Comments
 (0)