Skip to content

Commit f3dce46

Browse files
committed
update klustakwik tests to be reincluded in test runs
1 parent 4ff0a3e commit f3dce46

File tree

1 file changed

+84
-153
lines changed

1 file changed

+84
-153
lines changed

neo/test/iotest/test_klustakwikio.py

Lines changed: 84 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
Tests of neo.io.klustakwikio
33
"""
44

5-
import glob
6-
import os.path
7-
import sys
8-
import tempfile
5+
from pathlib import Path
6+
import shutil
97

108
import unittest
119

@@ -18,76 +16,64 @@
1816
from neo.io.klustakwikio import KlustaKwikIO
1917

2018

21-
class testFilenameParser(unittest.TestCase):
22-
"""Tests that filenames can be loaded with or without basename.
19+
class KlustaKwikTests(BaseTestIO, unittest.TestCase):
20+
ioclass = KlustaKwikIO
21+
entities_to_download = [
22+
'klustakwik'
23+
]
24+
entities_to_test = [
25+
'klustakwik/test2/base',
26+
'klustakwik/test2/base2',
27+
]
28+
29+
@classmethod
30+
def setUpClass(cls, *args, **kwargs):
31+
super(KlustaKwikTests, cls).setUpClass(*args, **kwargs)
32+
dirname = Path(cls.get_local_path('klustakwik'))
2333

24-
The test directory contains two basenames and some decoy files with
25-
malformed group numbers."""
34+
cls.session1 = dirname / 'test1'
35+
cls.session2 = dirname / 'test2'
36+
cls.session3 = dirname / 'test3'
37+
cls.tmp_session = dirname / 'tmp_session'
38+
cls.tmp_session.mkdir(exist_ok=True)
2639

27-
def setUp(self):
28-
self.dirname = os.path.join(tempfile.gettempdir(),
29-
'files_for_testing_neo',
30-
'klustakwik/test1')
31-
if not os.path.exists(self.dirname):
32-
raise unittest.SkipTest('data directory does not exist: ' +
33-
self.dirname)
40+
def tearDown(self) -> None:
41+
shutil.rmtree(self.tmp_session, ignore_errors=True)
42+
self.tmp_session.mkdir(exist_ok=True)
3443

35-
def test1(self):
44+
def test_load_by_basename(self):
3645
"""Tests that files can be loaded by basename"""
37-
kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename'))
38-
if not BaseTestIO.use_network:
39-
raise unittest.SkipTest("Requires download of data from the web")
46+
kio = KlustaKwikIO(dirname=self.session1 / 'basename')
4047
fetfiles = kio._fp.read_filenames('fet')
4148

4249
self.assertEqual(len(fetfiles), 2)
43-
self.assertEqual(os.path.abspath(fetfiles[0]),
44-
os.path.abspath(os.path.join(self.dirname,
45-
'basename.fet.0')))
46-
self.assertEqual(os.path.abspath(fetfiles[1]),
47-
os.path.abspath(os.path.join(self.dirname,
48-
'basename.fet.1')))
49-
50-
def test2(self):
51-
"""Tests that files are loaded even without basename"""
52-
pass
53-
54-
# this test is in flux, should probably have it default to
55-
# basename = os.path.split(dirname)[1] when dirname is a directory
56-
# dirname = os.path.normpath('./files_for_tests/klustakwik/test1')
57-
# kio = KlustaKwikIO(filename=dirname)
58-
# fetfiles = kio._fp.read_filenames('fet')
59-
60-
# It will just choose one of the two basenames, depending on which
61-
# is first, so just assert that it did something without error.
62-
# self.assertNotEqual(len(fetfiles), 0)
63-
64-
def test3(self):
50+
self.assertEqual(Path(fetfiles[0]).absolute(), self.session1 / 'basename.fet.0')
51+
self.assertEqual(Path(fetfiles[1]).absolute(), self.session1 / 'basename.fet.1')
52+
53+
# def test_load_without_basename(self):
54+
# """Tests that files are loaded even without basename"""
55+
# pass
56+
#
57+
# # this test is in flux, should probably have it default to
58+
# # basename = os.path.split(dirname)[1] when dirname is a directory)
59+
# kio = KlustaKwikIO(dirname=self.session1)
60+
# fetfiles = kio._fp.read_filenames('fet')
61+
#
62+
# # It will just choose one of the two basenames, depending on which
63+
# # is first, so just assert that it did something without error.
64+
# self.assertNotEqual(len(fetfiles), 0)
65+
66+
def test_load_by_basename_2(self):
6567
"""Tests that files can be loaded by basename2"""
66-
kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename2'))
67-
if not BaseTestIO.use_network:
68-
raise unittest.SkipTest("Requires download of data from the web")
68+
kio = KlustaKwikIO(dirname=self.session1 / 'basename2')
6969
clufiles = kio._fp.read_filenames('clu')
7070

7171
self.assertEqual(len(clufiles), 1)
72-
self.assertEqual(os.path.abspath(clufiles[1]),
73-
os.path.abspath(os.path.join(self.dirname,
74-
'basename2.clu.1')))
72+
self.assertEqual(Path(clufiles[1]).absolute(), self.session1 /'basename2.clu.1')
7573

76-
77-
class testRead(unittest.TestCase):
78-
"""Tests that data can be read from KlustaKwik files"""
79-
80-
def setUp(self):
81-
self.dirname = os.path.join(tempfile.gettempdir(),
82-
'files_for_testing_neo',
83-
'klustakwik/test2')
84-
if not os.path.exists(self.dirname):
85-
raise unittest.SkipTest('data directory does not exist: ' +
86-
self.dirname)
87-
88-
def test1(self):
74+
def test_read_data(self):
8975
"""Tests that data and metadata are read correctly"""
90-
kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base'),
76+
kio = KlustaKwikIO(dirname=self.session2 / 'base',
9177
sampling_rate=1000.)
9278
block = kio.read()[0]
9379
seg = block.segments[0]
@@ -122,9 +108,9 @@ def test1(self):
122108
self.assertTrue(np.all(seg.spiketrains[3].times == np.array([.050,
123109
.152])))
124110

125-
def test2(self):
111+
def test_default_cluster_id_0(self):
126112
"""Checks that cluster id autosets to 0 without clu file"""
127-
kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
113+
kio = KlustaKwikIO(dirname=self.session2 / 'base2',
128114
sampling_rate=1000.)
129115
block = kio.read()[0]
130116
seg = block.segments[0]
@@ -137,17 +123,7 @@ def test2(self):
137123
0.122,
138124
0.228])))
139125

140-
141-
class testWrite(unittest.TestCase):
142-
def setUp(self):
143-
self.dirname = os.path.join(tempfile.gettempdir(),
144-
'files_for_testing_neo',
145-
'klustakwik/test3')
146-
if not os.path.exists(self.dirname):
147-
raise unittest.SkipTest('data directory does not exist: ' +
148-
self.dirname)
149-
150-
def test1(self):
126+
def test_write_clu_and_fet(self):
151127
"""Create clu and fet files based on spiketrains in a block.
152128
153129
Checks that
@@ -192,70 +168,53 @@ def test1(self):
192168
st4.annotations['group'] = 2
193169
segment.spiketrains.append(st4)
194170

195-
# Create empty directory for writing
196-
delete_test_session()
197-
198171
# Create writer with default sampling rate
199-
kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base1'),
172+
kio = KlustaKwikIO(dirname=self.tmp_session / 'base1',
200173
sampling_rate=1000.)
201174
kio.write_block(block)
202175

203176
# Check files were created
204177
for fn in ['.fet.0', '.fet.1', '.clu.0', '.clu.1']:
205-
self.assertTrue(os.path.exists(os.path.join(self.dirname,
206-
'base1' + fn)))
178+
self.assertTrue((self.tmp_session / ('base1' + fn)).exists())
207179

208180
# Check files contain correct content
209181
# Spike times on group 0
210-
with open(os.path.join(self.dirname, 'base1.fet.0'), mode='r') as f:
182+
with open(self.tmp_session / 'base1.fet.0', mode='r') as f:
211183
data = f.readlines()
212184
data = [int(d) for d in data]
213185
self.assertEqual(data, [0, 2, 4, 6, 1, 3, 11, 106])
214186

215187
# Clusters on group 0
216-
with open(os.path.join(self.dirname, 'base1.clu.0'), mode='r') as f:
188+
with open(self.tmp_session / 'base1.clu.0', mode='r') as f:
217189
data = f.readlines()
218190
data = [int(d) for d in data]
219191
self.assertEqual(data, [2, 0, 0, 0, 1, 1, 1, 0])
220192

221193
# Spike times on group 1
222-
with open(os.path.join(self.dirname, 'base1.fet.1'), mode='r') as f:
194+
with open(self.tmp_session / 'base1.fet.1', mode='r') as f:
223195
data = f.readlines()
224196
data = [int(d) for d in data]
225197
self.assertEqual(data, [0, 50, 90, 100])
226198

227199
# Clusters on group 1
228-
with open(os.path.join(self.dirname, 'base1.clu.1')) as f:
200+
with open(self.tmp_session / 'base1.clu.1') as f:
229201
data = f.readlines()
230202
data = [int(d) for d in data]
231203
self.assertEqual(data, [1, -1, -1, -1])
232204

233205
# Spike times on group 2
234-
with open(os.path.join(self.dirname, 'base1.fet.2')) as f:
206+
with open(self.tmp_session / 'base1.fet.2') as f:
235207
data = f.readlines()
236208
data = [int(d) for d in data]
237209
self.assertEqual(data, [0, 5, 9])
238210

239211
# Clusters on group 2
240-
with open(os.path.join(self.dirname, 'base1.clu.2')) as f:
212+
with open(self.tmp_session / 'base1.clu.2') as f:
241213
data = f.readlines()
242214
data = [int(d) for d in data]
243215
self.assertEqual(data, [1, 0, 0])
244216

245-
# Empty out test session again
246-
delete_test_session()
247-
248-
249-
class testWriteWithFeatures(unittest.TestCase):
250-
def setUp(self):
251-
self.dirname = os.path.join(tempfile.gettempdir(),
252-
'files_for_testing_neo',
253-
'klustakwik/test4')
254-
if not os.path.exists(self.dirname):
255-
raise unittest.SkipTest('data directory does not exist: ' +
256-
self.dirname)
257-
258-
def test1(self):
217+
def test_write_clu_and_fet_1(self):
259218
"""Create clu and fet files based on spiketrains in a block.
260219
261220
Checks that
@@ -282,70 +241,42 @@ def test1(self):
282241
st1.annotations['waveform_features'] = wff
283242
segment.spiketrains.append(st1)
284243

285-
# Create empty directory for writing
286-
if not os.path.exists(self.dirname):
287-
os.mkdir(self.dirname)
288-
delete_test_session(self.dirname)
289-
290244
# Create writer
291-
kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
245+
kio = KlustaKwikIO(dirname=self.tmp_session / 'base2',
292246
sampling_rate=1000.)
293247
kio.write_block(block)
294248

295249
# Check files were created
296250
for fn in ['.fet.0', '.clu.0']:
297-
self.assertTrue(os.path.exists(os.path.join(self.dirname,
298-
'base2' + fn)))
251+
self.assertTrue(self.tmp_session / ('base2' + fn))
299252

300253
# Check files contain correct content
301-
fi = file(os.path.join(self.dirname, 'base2.fet.0'))
302-
303-
# first line is nbFeatures
304-
self.assertEqual(fi.readline(), '2\n')
305-
306-
# Now check waveforms and times are same
307-
data = fi.readlines()
308-
new_wff = []
309-
new_times = []
310-
for line in data:
311-
line_split = line.split()
312-
new_wff.append([float(val) for val in line_split[:-1]])
313-
new_times.append(int(line_split[-1]))
314-
self.assertEqual(new_times, [2, 4, 6])
315-
assert_arrays_almost_equal(wff, np.array(new_wff), .00001)
254+
with open(self.tmp_session / 'base2.fet.0') as fi:
255+
# first line is nbFeatures
256+
self.assertEqual(fi.readline(), '2\n')
257+
258+
# Now check waveforms and times are same
259+
data = fi.readlines()
260+
new_wff = []
261+
new_times = []
262+
for line in data:
263+
line_split = line.split()
264+
new_wff.append([float(val) for val in line_split[:-1]])
265+
new_times.append(int(line_split[-1]))
266+
self.assertEqual(new_times, [2, 4, 6])
267+
assert_arrays_almost_equal(wff, np.array(new_wff), .00001)
316268

317269
# Clusters on group 0
318-
data = file(os.path.join(self.dirname, 'base2.clu.0')).readlines()
319-
data = [int(d) for d in data]
320-
self.assertEqual(data, [1, 0, 0, 0])
321-
322-
# Now read the features and test same
323-
block = kio.read_block()
324-
train = block.segments[0].spiketrains[0]
325-
assert_arrays_almost_equal(wff, train.annotations['waveform_features'],
326-
.00001)
327-
328-
# Empty out test session again
329-
delete_test_session(self.dirname)
330-
331-
332-
class CommonTests(BaseTestIO, unittest.TestCase):
333-
ioclass = KlustaKwikIO
334-
entities_to_download = [
335-
'klustakwik'
336-
]
337-
entities_to_test = [
338-
'klustakwik/test2/base',
339-
'klustakwik/test2/base2',
340-
]
270+
with open(self.tmp_session / 'base2.clu.0') as fi:
271+
data = fi.readlines()
272+
data = [int(d) for d in data]
273+
self.assertEqual(data, [1, 0, 0, 0])
341274

342-
def delete_test_session(dirname=None):
343-
"""Removes all file in directory so we can test writing to it"""
344-
if dirname is None:
345-
dirname = os.path.join(os.path.dirname(__file__),
346-
'files_for_tests/klustakwik/test3')
347-
for fi in glob.glob(os.path.join(dirname, '*')):
348-
os.remove(fi)
275+
# Now read the features and test same
276+
block = kio.read_block()
277+
train = block.segments[0].spiketrains[0]
278+
assert_arrays_almost_equal(wff, train.annotations['waveform_features'],
279+
.00001)
349280

350281

351282
if __name__ == '__main__':

0 commit comments

Comments
 (0)