Skip to content

Commit 3853c75

Browse files
committed
Merge branch 'rjurkus-feature/phyio'
2 parents b804d8e + a7e148f commit 3853c75

File tree

7 files changed

+429
-0
lines changed

7 files changed

+429
-0
lines changed

doc/source/authors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ and may not be the current affiliation of a contributor.
5555
* Peter N Steinmetz [22]
5656
* Shashwat Sridhar
5757
* Alessio Buccino [23]
58+
* Regimantas Jurkus [13]
5859

5960
1. Centre de Recherche en Neuroscience de Lyon, CNRS UMR5292 - INSERM U1028 - Universite Claude Bernard Lyon 1
6061
2. Unité de Neuroscience, Information et Complexité, CNRS UPR 3293, Gif-sur-Yvette, France

neo/io/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
* :attr:`NixIO`
4444
* :attr:`NSDFIO`
4545
* :attr:`OpenEphysIO`
46+
* :attr:`PhyIO`
4647
* :attr:`PickleIO`
4748
* :attr:`PlexonIO`
4849
* :attr:`RawBinarySignalIO`
@@ -176,6 +177,10 @@
176177
177178
.. autoattribute:: extensions
178179
180+
.. autoclass:: neo.io.PhyIO
181+
182+
.. autoattribute:: extensions
183+
179184
.. autoclass:: neo.io.PickleIO
180185
181186
.. autoattribute:: extensions
@@ -271,6 +276,7 @@
271276
from neo.io.nixio_fr import NixIO as NixIOFr
272277
from neo.io.nsdfio import NSDFIO
273278
from neo.io.openephysio import OpenEphysIO
279+
from neo.io.phyio import PhyIO
274280
from neo.io.pickleio import PickleIO
275281
from neo.io.plexonio import PlexonIO
276282
from neo.io.rawbinarysignalio import RawBinarySignalIO
@@ -315,6 +321,7 @@
315321
NeuroshareIO,
316322
NSDFIO,
317323
OpenEphysIO,
324+
PhyIO,
318325
PickleIO,
319326
PlexonIO,
320327
RawBinarySignalIO,

neo/io/phyio.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from neo.io.basefromrawio import BaseFromRaw
2+
from neo.rawio.phyrawio import PhyRawIO
3+
4+
5+
class PhyIO(PhyRawIO, BaseFromRaw):
6+
name = 'Phy IO'
7+
description = "Phy IO"
8+
mode = 'dir'
9+
10+
def __init__(self, dirname):
11+
PhyRawIO.__init__(self, dirname=dirname)
12+
BaseFromRaw.__init__(self, dirname)

neo/rawio/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
* :attr:`NeuroScopeRawIO`
2626
* :attr:`NIXRawIO`
2727
* :attr:`OpenEphysRawIO`
28+
* :attr:'PhyRawIO'
2829
* :attr:`PlexonRawIO`
2930
* :attr:`RawBinarySignalRawIO`
3031
* :attr:`RawMCSRawIO`
@@ -87,6 +88,10 @@
8788
8889
.. autoattribute:: extensions
8990
91+
.. autoclass:: neo.rawio.PhyRawIO
92+
93+
.. autoattribute:: extensions
94+
9095
.. autoclass:: neo.rawio.PlexonRawIO
9196
9297
.. autoattribute:: extensions
@@ -136,6 +141,7 @@
136141
from neo.rawio.neuroscoperawio import NeuroScopeRawIO
137142
from neo.rawio.nixrawio import NIXRawIO
138143
from neo.rawio.openephysrawio import OpenEphysRawIO
144+
from neo.rawio.phyrawio import PhyRawIO
139145
from neo.rawio.plexonrawio import PlexonRawIO
140146
from neo.rawio.rawbinarysignalrawio import RawBinarySignalRawIO
141147
from neo.rawio.rawmcsrawio import RawMCSRawIO
@@ -159,6 +165,7 @@
159165
NeuroScopeRawIO,
160166
NIXRawIO,
161167
OpenEphysRawIO,
168+
PhyRawIO,
162169
PlexonRawIO,
163170
RawBinarySignalRawIO,
164171
RawMCSRawIO,

neo/rawio/phyrawio.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
"""
2+
PhyRawIO is a class to handle Phy spike sorting data.
3+
Ported from:
4+
https://github.com/SpikeInterface/spikeextractors/blob/
5+
f20b1219eba9d3330d5d7cd7ce8d8924a255b8c2/spikeextractors/
6+
extractors/phyextractors/phyextractors.py
7+
8+
Author: Regimantas Jurkus
9+
"""
10+
11+
from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
12+
_event_channel_dtype)
13+
14+
import numpy as np
15+
from pathlib import Path
16+
import re
17+
import csv
18+
import ast
19+
20+
21+
class PhyRawIO(BaseRawIO):
22+
"""
23+
Class for reading Phy data.
24+
25+
Usage:
26+
>>> import neo.rawio
27+
>>> r = neo.rawio.PhyRawIO(dirname='/dir/to/phy/folder')
28+
>>> r.parse_header()
29+
>>> print(r)
30+
>>> spike_timestamp = r.get_spike_timestamps(block_index=0,
31+
... seg_index=0, unit_index=0, t_start=None, t_stop=None)
32+
>>> spike_times = r.rescale_spike_timestamp(spike_timestamp, 'float64')
33+
34+
"""
35+
extensions = []
36+
rawmode = 'one-dir'
37+
38+
def __init__(self, dirname=''):
39+
BaseRawIO.__init__(self)
40+
self.dirname = dirname
41+
42+
def _source_name(self):
43+
return self.dirname
44+
45+
def _parse_header(self):
46+
phy_folder = Path(self.dirname)
47+
48+
self._spike_times = np.load(phy_folder / 'spike_times.npy')
49+
self._spike_templates = np.load(phy_folder / 'spike_templates.npy')
50+
51+
if (phy_folder / 'spike_clusters.npy').is_file():
52+
self._spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
53+
else:
54+
self._spike_clusters = self._spike_templates
55+
56+
# TODO: Add this when array_annotations are ready
57+
# if (phy_folder / 'amplitudes.npy').is_file():
58+
# amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
59+
# else:
60+
# amplitudes = np.ones(len(spike_times))
61+
#
62+
# if (phy_folder / 'pc_features.npy').is_file():
63+
# pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
64+
# else:
65+
# pc_features = None
66+
67+
# SEE: https://stackoverflow.com/questions/4388626/
68+
# python-safe-eval-string-to-bool-int-float-none-string
69+
if (phy_folder / 'params.py').is_file():
70+
with (phy_folder / 'params.py').open('r') as f:
71+
contents = f.read()
72+
metadata = dict()
73+
contents = contents.replace('\n', ' ')
74+
pattern = re.compile(r'(\S*)[\s]?=[\s]?(\S*)')
75+
elements = pattern.findall(contents)
76+
for key, value in elements:
77+
metadata[key.lower()] = ast.literal_eval(value)
78+
79+
self._sampling_frequency = metadata['sample_rate']
80+
81+
clust_ids = np.unique(self._spike_clusters)
82+
self.unit_labels = list(clust_ids)
83+
84+
self._t_start = 0.
85+
self._t_stop = max(self._spike_times).item() / self._sampling_frequency
86+
87+
sig_channels = []
88+
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
89+
90+
unit_channels = []
91+
for i, clust_id in enumerate(clust_ids):
92+
unit_name = f'unit {clust_id}'
93+
unit_id = f'{clust_id}'
94+
wf_units = ''
95+
wf_gain = 0
96+
wf_offset = 0.
97+
wf_left_sweep = 0
98+
wf_sampling_rate = 0
99+
unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
100+
wf_offset, wf_left_sweep, wf_sampling_rate))
101+
unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
102+
103+
event_channels = []
104+
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
105+
106+
self.header = {}
107+
self.header['nb_block'] = 1
108+
self.header['nb_segment'] = [1]
109+
self.header['signal_channels'] = sig_channels
110+
self.header['unit_channels'] = unit_channels
111+
self.header['event_channels'] = event_channels
112+
113+
self._generate_minimal_annotations()
114+
115+
csv_tsv_files = [x for x in phy_folder.iterdir() if
116+
x.suffix == '.csv' or x.suffix == '.tsv']
117+
118+
# annotation_lists is list of list of dict (python==3.8)
119+
# or list of list of ordered dict (python==3.6)
120+
# SEE: https://docs.python.org/3/library/csv.html#csv.DictReader
121+
annotation_lists = [self._parse_tsv_or_csv_to_list_of_dict(file)
122+
for file in csv_tsv_files]
123+
124+
bl_ann = self.raw_annotations['blocks'][0]
125+
bl_ann['name'] = "Block #0"
126+
seg_ann = bl_ann['segments'][0]
127+
seg_ann['name'] = 'Seg #0 Block #0'
128+
for index, clust_id in enumerate(clust_ids):
129+
spiketrain_an = seg_ann['units'][index]
130+
131+
# Loop over list of list of dict and annotate each st
132+
for annotation_list in annotation_lists:
133+
clust_key, property_name = tuple(annotation_list[0].
134+
keys())
135+
if property_name == 'KSLabel':
136+
annotation_name = 'quality'
137+
else:
138+
annotation_name = property_name.lower()
139+
for annotation_dict in annotation_list:
140+
if int(annotation_dict[clust_key]) == clust_id:
141+
spiketrain_an[annotation_name] = \
142+
annotation_dict[property_name]
143+
break
144+
145+
def _segment_t_start(self, block_index, seg_index):
146+
assert block_index == 0
147+
return self._t_start
148+
149+
def _segment_t_stop(self, block_index, seg_index):
150+
assert block_index == 0
151+
return self._t_stop
152+
153+
def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
154+
return None
155+
156+
def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
157+
return None
158+
159+
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
160+
channel_indexes):
161+
return None
162+
163+
def _spike_count(self, block_index, seg_index, unit_index):
164+
assert block_index == 0
165+
spikes = self._spike_clusters
166+
unit_label = self.unit_labels[unit_index]
167+
mask = spikes == unit_label
168+
nb_spikes = np.sum(mask)
169+
return nb_spikes
170+
171+
def _get_spike_timestamps(self, block_index, seg_index, unit_index,
172+
t_start, t_stop):
173+
assert block_index == 0
174+
assert seg_index == 0
175+
176+
unit_label = self.unit_labels[unit_index]
177+
mask = self._spike_clusters == unit_label
178+
spike_timestamps = self._spike_times[mask]
179+
180+
if t_start is not None:
181+
start_frame = int(t_start * self._sampling_frequency)
182+
spike_timestamps = spike_timestamps[spike_timestamps >=
183+
start_frame]
184+
if t_stop is not None:
185+
end_frame = int(t_stop * self._sampling_frequency)
186+
spike_timestamps = spike_timestamps[spike_timestamps < end_frame]
187+
188+
return spike_timestamps
189+
190+
def _rescale_spike_timestamp(self, spike_timestamps, dtype):
191+
spike_times = spike_timestamps.astype(dtype)
192+
spike_times /= self._sampling_frequency
193+
return spike_times
194+
195+
def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
196+
t_start, t_stop):
197+
return None
198+
199+
def _event_count(self, block_index, seg_index, event_channel_index):
200+
return None
201+
202+
def _get_event_timestamps(self, block_index, seg_index,
203+
event_channel_index, t_start, t_stop):
204+
return None
205+
206+
def _rescale_event_timestamp(self, event_timestamps, dtype):
207+
return None
208+
209+
def _rescale_epoch_duration(self, raw_duration, dtype):
210+
return None
211+
212+
@staticmethod
213+
def _parse_tsv_or_csv_to_list_of_dict(filename):
214+
list_of_dict = list()
215+
with open(filename) as csvfile:
216+
if filename.suffix == '.csv':
217+
reader = csv.DictReader(csvfile, delimiter=',')
218+
elif filename.suffix == '.tsv':
219+
reader = csv.DictReader(csvfile, delimiter='\t')
220+
else:
221+
raise ValueError("Function parses only .csv or .tsv files")
222+
for row in reader:
223+
list_of_dict.append(row)
224+
225+
return list_of_dict

0 commit comments

Comments
 (0)