Skip to content

Commit 758dc07

Browse files
author
sprenger
committed
[TDTIO] Switch to pathlib and permit to load single block
1 parent d99f146 commit 758dc07

File tree

2 files changed

+82
-36
lines changed

2 files changed

+82
-36
lines changed

neo/rawio/tdtrawio.py

Lines changed: 80 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,49 +28,73 @@
2828
import numpy as np
2929
import os
3030
import re
31+
import warnings
3132
from collections import OrderedDict
33+
from pathlib import Path
3234

3335

3436
class TdtRawIO(BaseRawIO):
3537
rawmode = 'one-dir'
3638

3739
def __init__(self, dirname='', sortname=''):
3840
"""
39-
'sortname' is used to specify the external sortcode generated by offline spike sorting.
40-
if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
41-
which stores the sortcode for every spike; defaults to '',
42-
which uses the original online sort.
41+
Initialize reader for one or multiple TDT data blocks.
42+
43+
dirname (str, pathlib.Path):
44+
tank-directory of a dataset to be read as multiple segments OR single file of dataset.
45+
In the latter case only the corresponding segment will considered.
46+
sortname (str):
47+
'sortname' is used to specify the external sortcode generated by offline spike sorting.
48+
if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
49+
which stores the sortcode for every spike
50+
Default: '', uses the original online sort.
51+
52+
4353
"""
4454
BaseRawIO.__init__(self)
45-
dirname = str(dirname)
46-
if dirname.endswith('/'):
47-
dirname = dirname[:-1]
48-
self.dirname = dirname
55+
dirname = Path(dirname)
56+
if dirname.is_dir():
57+
self.dirname = Path(dirname)
58+
self.tdt_block_mode = 'multi'
59+
elif dirname.is_file():
60+
# in single tdt block mode the dirname also contains the block prefix
61+
self.dirname = dirname.with_suffix('')
62+
self.tdt_block_mode = 'single'
4963

5064
self.sortname = sortname
5165

5266
def _source_name(self):
5367
return self.dirname
5468

55-
def _parse_header(self):
56-
57-
tankname = os.path.basename(self.dirname)
69+
def _get_filestem(self, segment_name=''):
70+
if self.tdt_block_mode == 'multi':
71+
return self.dirname / segment_name / f'{self.dirname.name}_{segment_name}'
72+
else:
73+
return self.dirname
5874

75+
def _parse_header(self):
5976
segment_names = []
60-
for segment_name in os.listdir(self.dirname):
61-
path = os.path.join(self.dirname, segment_name)
62-
if is_tdtblock(path):
63-
segment_names.append(segment_name)
77+
if self.tdt_block_mode == 'multi':
78+
tankname = self.dirname.stem
79+
for path in self.dirname.iterdir():
80+
if is_tdtblock(path):
81+
segment_names.append(path.stem)
82+
83+
# if no block structure was detected, check if current dir contains a set of data
84+
elif is_tdtblock(self.dirname.parent):
85+
segment_names.append(str(self.dirname.stem))
86+
tankname = None
6487

6588
nb_segment = len(segment_names)
89+
if nb_segment == 0:
90+
warnings.warn(f'Could not find any data set belonging to {self.dirname}')
6691

6792
# TBK (channel info)
6893
info_channel_groups = None
6994
for seg_index, segment_name in enumerate(segment_names):
70-
path = os.path.join(self.dirname, segment_name)
7195

7296
# TBK contain channels
73-
tbk_filename = os.path.join(path, tankname + '_' + segment_name + '.Tbk')
97+
tbk_filename = self._get_filestem(segment_name).with_suffix('.Tbk')
7498
_info_channel_groups = read_tbk(tbk_filename)
7599
if info_channel_groups is None:
76100
info_channel_groups = _info_channel_groups
@@ -81,9 +105,8 @@ def _parse_header(self):
81105
# TEV (mixed data)
82106
self._tev_datas = []
83107
for seg_index, segment_name in enumerate(segment_names):
84-
path = os.path.join(self.dirname, segment_name)
85-
tev_filename = os.path.join(path, tankname + '_' + segment_name + '.tev')
86-
if os.path.exists(tev_filename):
108+
tev_filename = self._get_filestem(segment_name).with_suffix('.tev')
109+
if tev_filename.exists():
87110
tev_data = np.memmap(tev_filename, mode='r', offset=0, dtype='uint8')
88111
else:
89112
tev_data = None
@@ -94,8 +117,7 @@ def _parse_header(self):
94117
self._seg_t_starts = []
95118
self._seg_t_stops = []
96119
for seg_index, segment_name in enumerate(segment_names):
97-
path = os.path.join(self.dirname, segment_name)
98-
tsq_filename = os.path.join(path, tankname + '_' + segment_name + '.tsq')
120+
tsq_filename = self._get_filestem(segment_name).with_suffix('.tsq')
99121
tsq = np.fromfile(tsq_filename, dtype=tsq_dtype)
100122
self._tsq.append(tsq)
101123
# Start and stop times are only found in the second
@@ -115,9 +137,13 @@ def _parse_header(self):
115137
# (generated after offline sorting)
116138
if self.sortname != '':
117139
try:
118-
for file in os.listdir(os.path.join(path, 'sort', sortname)):
140+
if self.tdt_block_mode == 'multi':
141+
path = self.dirname
142+
else:
143+
path = self.dirname.parent
144+
for file in os.listdir(path / 'sort' / self.sortname):
119145
if file.endswith(".SortResult"):
120-
sortresult_filename = os.path.join(path, 'sort', sortname, file)
146+
sortresult_filename = path / 'sort' / self.sortname / file
121147
# get new sortcode
122148
newsortcode = np.fromfile(sortresult_filename, 'int8')[
123149
1024:] # first 1024 bytes are header
@@ -181,15 +207,22 @@ def _parse_header(self):
181207
assert self._sigs_lengths[seg_index][stream_index] == size
182208

183209
# signal start time, relative to start of segment
184-
t_start = data_index['timestamp'][0]
210+
if len(data_index['timestamp']):
211+
t_start = data_index['timestamp'][0]
212+
else:
213+
t_start = None
185214
if stream_index not in self._sigs_t_start[seg_index]:
186215
self._sigs_t_start[seg_index][stream_index] = t_start
187216
else:
188217
assert self._sigs_t_start[seg_index][stream_index] == t_start
189218

190219
# sampling_rate and dtype
191-
_sampling_rate = float(data_index['frequency'][0])
192-
_dtype = data_formats[data_index['dataformat'][0]]
220+
if len(data_index):
221+
_sampling_rate = float(data_index['frequency'][0])
222+
_dtype = data_formats[data_index['dataformat'][0]]
223+
else:
224+
_sampling_rate = np.nan
225+
_dtype = type(None)
193226
if sampling_rate is None:
194227
sampling_rate = _sampling_rate
195228
dtype = _dtype
@@ -202,11 +235,23 @@ def _parse_header(self):
202235
assert dtype == _dtype, 'sampling is changing!!!'
203236

204237
# data buffer test if SEV file exists otherwise TEV
205-
path = os.path.join(self.dirname, segment_name)
206-
sev_filename = os.path.join(path, tankname + '_' + segment_name + '_'
207-
+ info['StoreName'].decode('ascii')
208-
+ '_ch' + str(chan_id) + '.sev')
209-
if os.path.exists(sev_filename):
238+
# path = self.dirname / segment_name
239+
if self.tdt_block_mode == 'multi':
240+
# for multi block datasets the names of sev files are fixed
241+
store = info['StoreName'].decode('ascii')
242+
sev_stem = tankname + '_' + segment_name + '_' + store + '_ch' + str(chan_id)
243+
sev_filename = (path / sev_stem).with_suffix('.sev')
244+
else:
245+
# for single block datasets the exact name of sev files in not known
246+
sev_regex = f".*_ch{chan_id}.sev"
247+
sev_filename = list(self.dirname.parent.glob(str(sev_regex)))
248+
249+
# in case non or multiple sev files are found for current stream + channel
250+
if len(sev_filename) != 1:
251+
warnings.warn(f'Could not identify sev file for channel {chan_id}.')
252+
sev_filename = None
253+
254+
if (sev_filename is not None) and sev_filename.exists():
210255
data = np.memmap(sev_filename, mode='r', offset=0, dtype='uint8')
211256
else:
212257
data = self._tev_datas[seg_index]
@@ -526,10 +571,10 @@ def read_tbk(tbk_filename):
526571
def is_tdtblock(blockpath):
527572
"""Is tha path a TDT block (=neo.Segment) ?"""
528573
file_ext = list()
529-
if os.path.isdir(blockpath):
574+
if blockpath.is_dir():
530575
# for every file, get extension, convert to lowercase and append
531-
for file in os.listdir(blockpath):
532-
file_ext.append(os.path.splitext(file)[1].lower())
576+
for file in blockpath.iterdir():
577+
file_ext.append(file.suffix.lower())
533578

534579
file_ext = set(file_ext)
535580
tdt_ext = {'.tbk', '.tdx', '.tev', '.tsq'}

neo/test/rawiotest/test_tdtrawio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ class TestTdtRawIO(BaseTestRawIO, unittest.TestCase, ):
1010
'tdt'
1111
]
1212
entities_to_test = [
13-
'tdt/aep_05'
13+
'tdt/aep_05',
14+
'tdt/aep_05/Block-1/aep_05_Block-1.Tdx'
1415
]
1516

1617

0 commit comments

Comments
 (0)