Skip to content

Commit a95a140

Browse files
authored
Merge pull request #1057 from JuliaSprenger/enh/tdt-block
Extend tdtrawio to support single tdt block reading and empty streams
2 parents a9a3386 + 6d59707 commit a95a140

File tree

2 files changed

+127
-36
lines changed

2 files changed

+127
-36
lines changed

neo/rawio/tdtrawio.py

Lines changed: 84 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,49 +28,75 @@
2828
import numpy as np
2929
import os
3030
import re
31+
import warnings
3132
from collections import OrderedDict
33+
from pathlib import Path
3234

3335

3436
class TdtRawIO(BaseRawIO):
3537
rawmode = 'one-dir'
3638

3739
def __init__(self, dirname='', sortname=''):
3840
"""
39-
'sortname' is used to specify the external sortcode generated by offline spike sorting.
40-
if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
41-
which stores the sortcode for every spike; defaults to '',
42-
which uses the original online sort.
41+
Initialize reader for one or multiple TDT data blocks.
42+
43+
dirname (str, pathlib.Path):
44+
tank-directory of a dataset to be read as multiple segments OR single file of dataset.
45+
In the latter case only the corresponding segment will considered.
46+
sortname (str):
47+
'sortname' is used to specify the external sortcode generated by offline spike sorting.
48+
if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
49+
which stores the sortcode for every spike
50+
Default: '', uses the original online sort.
51+
52+
4353
"""
4454
BaseRawIO.__init__(self)
45-
dirname = str(dirname)
46-
if dirname.endswith('/'):
47-
dirname = dirname[:-1]
48-
self.dirname = dirname
55+
dirname = Path(dirname)
56+
if dirname.is_dir():
57+
self.dirname = Path(dirname)
58+
self.tdt_block_mode = 'multi'
59+
elif dirname.is_file():
60+
# in single tdt block mode the dirname also contains the block prefix
61+
self.dirname = dirname.with_suffix('')
62+
self.tdt_block_mode = 'single'
63+
else:
64+
raise ValueError(f'No data folder or file found for {dirname}')
4965

5066
self.sortname = sortname
5167

5268
def _source_name(self):
5369
return self.dirname
5470

55-
def _parse_header(self):
56-
57-
tankname = os.path.basename(self.dirname)
71+
def _get_filestem(self, segment_name=''):
72+
if self.tdt_block_mode == 'multi':
73+
return self.dirname / segment_name / f'{self.dirname.name}_{segment_name}'
74+
else:
75+
return self.dirname
5876

77+
def _parse_header(self):
5978
segment_names = []
60-
for segment_name in os.listdir(self.dirname):
61-
path = os.path.join(self.dirname, segment_name)
62-
if is_tdtblock(path):
63-
segment_names.append(segment_name)
79+
if self.tdt_block_mode == 'multi':
80+
tankname = self.dirname.stem
81+
for path in self.dirname.iterdir():
82+
if is_tdtblock(path):
83+
segment_names.append(path.stem)
84+
85+
# if no block structure was detected, check if current dir contains a set of data
86+
elif is_tdtblock(self.dirname.parent):
87+
segment_names.append(str(self.dirname.stem))
88+
tankname = None
6489

6590
nb_segment = len(segment_names)
91+
if nb_segment == 0:
92+
warnings.warn(f'Could not find any data set belonging to {self.dirname}')
6693

6794
# TBK (channel info)
6895
info_channel_groups = None
6996
for seg_index, segment_name in enumerate(segment_names):
70-
path = os.path.join(self.dirname, segment_name)
7197

7298
# TBK contain channels
73-
tbk_filename = os.path.join(path, tankname + '_' + segment_name + '.Tbk')
99+
tbk_filename = self._get_filestem(segment_name).with_suffix('.Tbk')
74100
_info_channel_groups = read_tbk(tbk_filename)
75101
if info_channel_groups is None:
76102
info_channel_groups = _info_channel_groups
@@ -81,9 +107,8 @@ def _parse_header(self):
81107
# TEV (mixed data)
82108
self._tev_datas = []
83109
for seg_index, segment_name in enumerate(segment_names):
84-
path = os.path.join(self.dirname, segment_name)
85-
tev_filename = os.path.join(path, tankname + '_' + segment_name + '.tev')
86-
if os.path.exists(tev_filename):
110+
tev_filename = self._get_filestem(segment_name).with_suffix('.tev')
111+
if tev_filename.exists():
87112
tev_data = np.memmap(tev_filename, mode='r', offset=0, dtype='uint8')
88113
else:
89114
tev_data = None
@@ -94,8 +119,7 @@ def _parse_header(self):
94119
self._seg_t_starts = []
95120
self._seg_t_stops = []
96121
for seg_index, segment_name in enumerate(segment_names):
97-
path = os.path.join(self.dirname, segment_name)
98-
tsq_filename = os.path.join(path, tankname + '_' + segment_name + '.tsq')
122+
tsq_filename = self._get_filestem(segment_name).with_suffix('.tsq')
99123
tsq = np.fromfile(tsq_filename, dtype=tsq_dtype)
100124
self._tsq.append(tsq)
101125
# Start and stop times are only found in the second
@@ -115,9 +139,13 @@ def _parse_header(self):
115139
# (generated after offline sorting)
116140
if self.sortname != '':
117141
try:
118-
for file in os.listdir(os.path.join(path, 'sort', sortname)):
142+
if self.tdt_block_mode == 'multi':
143+
path = self.dirname
144+
else:
145+
path = self.dirname.parent
146+
for file in os.listdir(path / 'sort' / self.sortname):
119147
if file.endswith(".SortResult"):
120-
sortresult_filename = os.path.join(path, 'sort', sortname, file)
148+
sortresult_filename = path / 'sort' / self.sortname / file
121149
# get new sortcode
122150
newsortcode = np.fromfile(sortresult_filename, 'int8')[
123151
1024:] # first 1024 bytes are header
@@ -181,15 +209,24 @@ def _parse_header(self):
181209
assert self._sigs_lengths[seg_index][stream_index] == size
182210

183211
# signal start time, relative to start of segment
184-
t_start = data_index['timestamp'][0]
212+
if len(data_index['timestamp']):
213+
t_start = data_index['timestamp'][0]
214+
else:
215+
# if no signal present use segment t_start as dummy value
216+
t_start = self._seg_t_starts[seg_index]
185217
if stream_index not in self._sigs_t_start[seg_index]:
186218
self._sigs_t_start[seg_index][stream_index] = t_start
187219
else:
188220
assert self._sigs_t_start[seg_index][stream_index] == t_start
189221

190222
# sampling_rate and dtype
191-
_sampling_rate = float(data_index['frequency'][0])
192-
_dtype = data_formats[data_index['dataformat'][0]]
223+
if len(data_index):
224+
_sampling_rate = float(data_index['frequency'][0])
225+
_dtype = data_formats[data_index['dataformat'][0]]
226+
else:
227+
# if no signal present use dummy values
228+
_sampling_rate = 1.
229+
_dtype = int
193230
if sampling_rate is None:
194231
sampling_rate = _sampling_rate
195232
dtype = _dtype
@@ -202,11 +239,23 @@ def _parse_header(self):
202239
assert dtype == _dtype, 'sampling is changing!!!'
203240

204241
# data buffer test if SEV file exists otherwise TEV
205-
path = os.path.join(self.dirname, segment_name)
206-
sev_filename = os.path.join(path, tankname + '_' + segment_name + '_'
207-
+ info['StoreName'].decode('ascii')
208-
+ '_ch' + str(chan_id) + '.sev')
209-
if os.path.exists(sev_filename):
242+
# path = self.dirname / segment_name
243+
if self.tdt_block_mode == 'multi':
244+
# for multi block datasets the names of sev files are fixed
245+
store = info['StoreName'].decode('ascii')
246+
sev_stem = f'{tankname}_{segment_name}_{store}_ch{chan_id}'
247+
sev_filename = (path / sev_stem).with_suffix('.sev')
248+
else:
249+
# for single block datasets the exact name of sev files in not known
250+
sev_regex = f".*_ch{chan_id}.sev"
251+
sev_filename = list(self.dirname.parent.glob(str(sev_regex)))
252+
253+
# in case non or multiple sev files are found for current stream + channel
254+
if len(sev_filename) != 1:
255+
warnings.warn(f'Could not identify sev file for channel {chan_id}.')
256+
sev_filename = None
257+
258+
if (sev_filename is not None) and sev_filename.exists():
210259
data = np.memmap(sev_filename, mode='r', offset=0, dtype='uint8')
211260
else:
212261
data = self._tev_datas[seg_index]
@@ -526,10 +575,10 @@ def read_tbk(tbk_filename):
526575
def is_tdtblock(blockpath):
527576
"""Is tha path a TDT block (=neo.Segment) ?"""
528577
file_ext = list()
529-
if os.path.isdir(blockpath):
578+
if blockpath.is_dir():
530579
# for every file, get extension, convert to lowercase and append
531-
for file in os.listdir(blockpath):
532-
file_ext.append(os.path.splitext(file)[1].lower())
580+
for file in blockpath.iterdir():
581+
file_ext.append(file.suffix.lower())
533582

534583
file_ext = set(file_ext)
535584
tdt_ext = {'.tbk', '.tdx', '.tev', '.tsq'}

neo/test/rawiotest/test_tdtrawio.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import unittest
2+
from pathlib import Path
3+
from numpy.testing import assert_array_equal, assert_
24

35
from neo.rawio.tdtrawio import TdtRawIO
46
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
@@ -10,9 +12,49 @@ class TestTdtRawIO(BaseTestRawIO, unittest.TestCase, ):
1012
'tdt'
1113
]
1214
entities_to_test = [
13-
'tdt/aep_05'
15+
'tdt/aep_05',
16+
'tdt/aep_05/Block-1/aep_05_Block-1.Tdx'
1417
]
1518

19+
def test_invalid_dirname(self):
20+
invalid_name = 'random_non_existant_tdt_filename'
21+
assert not Path(invalid_name).exists()
22+
23+
with self.assertRaises(ValueError):
24+
TdtRawIO(invalid_name)
25+
26+
def test_compare_load_multi_single_block(self):
27+
dirname = self.get_local_path('tdt/aep_05')
28+
filename = self.get_local_path('tdt/aep_05/Block-1/aep_05_Block-1.Tdx')
29+
30+
io_single = TdtRawIO(filename)
31+
io_multi = TdtRawIO(dirname)
32+
33+
io_single.parse_header()
34+
io_multi.parse_header()
35+
36+
self.assertEqual(io_single.tdt_block_mode, 'single')
37+
self.assertEqual(io_multi.tdt_block_mode, 'multi')
38+
39+
self.assertEqual(io_single.block_count(), 1)
40+
self.assertEqual(io_multi.block_count(), 1)
41+
42+
self.assertEqual(io_single.segment_count(0), 1)
43+
self.assertEqual(io_multi.segment_count(0), 2)
44+
45+
# compare header infos
46+
assert_array_equal(io_single.header['signal_streams'], io_multi.header['signal_streams'])
47+
assert_array_equal(io_single.header['signal_channels'], io_multi.header['signal_channels'])
48+
assert_array_equal(io_single.header['event_channels'], io_multi.header['event_channels'])
49+
50+
# not all spiking channels are present in first tdt block (segment)
51+
for spike_channel in io_single.header['spike_channels']:
52+
self.assertIn(spike_channel, io_multi.header['spike_channels'])
53+
54+
# check that extracted signal chunks are identical
55+
assert_array_equal(io_single.get_analogsignal_chunk(0, 0, 0, 100, 0),
56+
io_multi.get_analogsignal_chunk(0, 0, 0, 100, 0))
57+
1658

1759
if __name__ == "__main__":
1860
unittest.main()

0 commit comments

Comments
 (0)