Skip to content

Commit f3704d3

Browse files
committed
added tests and passing them
1 parent 43b86a7 commit f3704d3

File tree

2 files changed

+75
-18
lines changed

2 files changed

+75
-18
lines changed

neo/rawio/mearecrawio.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ class MEArecRawIO(BaseRawIO):
2020
"""
2121
Class for "reading" fake data from a MEArec file.
2222
23+
This class provides a convenient way to read data from a MEArec file.
24+
25+
Parameters
26+
----------
27+
filename : str
28+
The filename of the MEArec file to read.
29+
load_spiketrains : bool, optional
30+
Whether or not to load spike train data. Defaults to `True`.
31+
load_recordings : bool, optional
32+
Whether or not to load continuous recording data. Defaults to `True`.
33+
34+
2335
Usage:
2436
>>> import neo.rawio
2537
>>> r = neo.rawio.MEArecRawIO(filename='mearec.h5')
@@ -46,35 +58,43 @@ def _source_name(self):
4658
return self.filename
4759

4860
def _parse_header(self):
49-
import MEArec as mr
50-
load = ['channel_positions']
61+
load = []
5162
if self.load_recordings:
5263
load.append("recordings")
5364
if self.load_spiketrains:
5465
load.append("spiketrains")
5566

67+
import MEArec as mr
5668
self._recgen = mr.load_recordings(recordings=self.filename, return_h5_objects=True,
5769
check_suffix=False,
5870
load=load,
5971
load_waveforms=False)
60-
self._sampling_rate = self._recgen.info['recordings']['fs']
61-
62-
signal_streams = np.array([('Signals', '0')], dtype=_signal_stream_dtype)
72+
73+
self.info_dict = self._recgen.info
74+
self._sampling_rate = self.info_dict['recordings']['fs']
75+
self.duration_seconds = self.info_dict["recordings"]["duration"]
76+
self._num_frames = int(self._sampling_rate * self.duration_seconds)
77+
self._num_channels = np.sum(self.info_dict["electrodes"]["dim"])
78+
self._dtype = self.info_dict["recordings"]["dtype"]
79+
80+
signals = [('Signals', '0')]
81+
signal_streams = np.array(signals, dtype=_signal_stream_dtype)
6382

64-
sig_channels = []
6583
if self.load_recordings:
6684
self._recordings = self._recgen.recordings
67-
self._num_frames, self._num_channels = self._recordings.shape
68-
for c in range(self._num_channels):
69-
ch_name = 'ch{}'.format(c)
70-
chan_id = str(c + 1)
71-
sr = self._sampling_rate # Hz
72-
dtype = self._recordings.dtype
73-
units = 'uV'
74-
gain = 1.
75-
offset = 0.
76-
stream_id = '0'
77-
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
85+
86+
sig_channels = []
87+
for c in range(self._num_channels):
88+
ch_name = 'ch{}'.format(c)
89+
chan_id = str(c + 1)
90+
sr = self._sampling_rate # Hz
91+
dtype = self._dtype
92+
units = 'uV'
93+
gain = 1.
94+
offset = 0.
95+
stream_id = '0'
96+
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
97+
7898
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
7999

80100
# creating units channels
@@ -92,7 +112,8 @@ def _parse_header(self):
92112
wf_sampling_rate = self._sampling_rate
93113
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
94114
wf_offset, wf_left_sweep, wf_sampling_rate))
95-
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
115+
116+
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
96117

97118
event_channels = []
98119
event_channels = np.array(event_channels, dtype=_event_channel_dtype)

neo/test/rawiotest/test_mearecrawio.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,42 @@ class TestMEArecRawIO(BaseTestRawIO, unittest.TestCase, ):
2727
'mearec/mearec_test_10s.h5'
2828
]
2929

30+
def test_not_loading_recordings(self):
31+
32+
filename = self.entities_to_test[0]
33+
filename = self.get_local_path(filename)
34+
rawio = self.rawioclass(filename=filename, load_recordings=False)
35+
rawio.parse_header()
36+
37+
# Test that rawio does not have a _recordings attribute
38+
self.assertFalse(hasattr(rawio, '_recordings'))
39+
40+
# Test that calling get_spike_timestamps works
41+
rawio.get_spike_timestamps()
42+
43+
# Test that caling anlogsignal chunk raises the right error
44+
error_message = "Recordings not loaded. Set load_recordings=True in MEArecRawIO constructor"
45+
with self.assertRaises(AttributeError, msg=error_message):
46+
rawio.get_analogsignal_chunk()
47+
48+
49+
def test_not_loading_spiketrain(self):
50+
51+
filename = self.entities_to_test[0]
52+
filename = self.get_local_path(filename)
53+
rawio = self.rawioclass(filename=filename, load_spiketrains=False)
54+
rawio.parse_header()
55+
56+
# Test that rawio does not have a _spiketrains attribute
57+
self.assertFalse(hasattr(rawio, '_spiketrains'))
58+
59+
# Test that calling analogsignal chunk works
60+
rawio.get_analogsignal_chunk()
61+
62+
# Test that calling get_spike_timestamps raises an the right error
63+
error_message = "Spiketrains not loaded. Set load_spiketrains=True in MEArecRawIO constructor"
64+
with self.assertRaises(AttributeError, msg=error_message):
65+
rawio.get_spike_timestamps()
3066

3167
if __name__ == "__main__":
3268
unittest.main()

0 commit comments

Comments
 (0)