Skip to content

Commit 1603a7c

Browse files
author
sprenger
committed
[TDT] add test for single tdt block loading
1 parent 758dc07 commit 1603a7c

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

neo/rawio/tdtrawio.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self, dirname='', sortname=''):
6060
# in single tdt block mode the dirname also contains the block prefix
6161
self.dirname = dirname.with_suffix('')
6262
self.tdt_block_mode = 'single'
63+
else:
64+
raise ValueError(f'No data folder or file found for {dirname}')
6365

6466
self.sortname = sortname
6567

neo/test/rawiotest/test_tdtrawio.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import unittest
2+
from pathlib import Path
3+
from numpy.testing import assert_array_equal, assert_
24

35
from neo.rawio.tdtrawio import TdtRawIO
46
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
@@ -14,6 +16,45 @@ class TestTdtRawIO(BaseTestRawIO, unittest.TestCase, ):
1416
'tdt/aep_05/Block-1/aep_05_Block-1.Tdx'
1517
]
1618

19+
def test_invalid_dirname(self):
20+
invalid_name = 'random_non_existant_tdt_filename'
21+
assert not Path(invalid_name).exists()
22+
23+
with self.assertRaises(ValueError):
24+
TdtRawIO(invalid_name)
25+
26+
def test_compare_load_multi_single_block(self):
27+
dirname = self.get_local_path('tdt/aep_05')
28+
filename = self.get_local_path('tdt/aep_05/Block-1/aep_05_Block-1.Tdx')
29+
30+
io_single = TdtRawIO(filename)
31+
io_multi = TdtRawIO(dirname)
32+
33+
io_single.parse_header()
34+
io_multi.parse_header()
35+
36+
self.assertEqual(io_single.tdt_block_mode, 'single')
37+
self.assertEqual(io_multi.tdt_block_mode, 'multi')
38+
39+
self.assertEqual(io_single.block_count(), 1)
40+
self.assertEqual(io_multi.block_count(), 1)
41+
42+
self.assertEqual(io_single.segment_count(0), 1)
43+
self.assertEqual(io_multi.segment_count(0), 2)
44+
45+
# compare header infos
46+
assert_array_equal(io_single.header['signal_streams'], io_multi.header['signal_streams'])
47+
assert_array_equal(io_single.header['signal_channels'], io_multi.header['signal_channels'])
48+
assert_array_equal(io_single.header['event_channels'], io_multi.header['event_channels'])
49+
50+
# not all spiking channels are present in first tdt block (segment)
51+
for spike_channel in io_single.header['spike_channels']:
52+
self.assertIn(spike_channel, io_multi.header['spike_channels'])
53+
54+
# check that extracted signal chunks are identical
55+
assert_array_equal(io_single.get_analogsignal_chunk(0, 0, 0, 100, 0),
56+
io_multi.get_analogsignal_chunk(0, 0, 0, 100, 0))
57+
1758

1859
if __name__ == "__main__":
1960
unittest.main()

0 commit comments

Comments
 (0)