Skip to content

Commit 20a753f

Browse files
committed
Fix tests
1 parent 4cc4c79 commit 20a753f

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

neo/rawio/spikeglxrawio.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
290290
return raw_signals
291291

292292
def _event_count(self, event_channel_idx, block_index=None, seg_index=None):
293-
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_index,
294-
None, None)
293+
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None)
295294
return timestamps.size
296295

297296
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start=None, t_stop=None):
@@ -304,7 +303,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
304303
ch_idx = 7 - int(channel[2:]) # They are in the reverse order
305304
this_stream = event_data[:,ch_idx]
306305
this_rising = np.where(np.diff(this_stream)==1)[0] + 1
307-
this_falling = np.where(np.diff(this_stream)==255)[0] + 1 #behcause the data is in unsigned 8 bit, -1 = 255!
306+
this_falling = np.where(np.diff(this_stream)==255)[0] + 1 # because the data is in unsigned 8 bit, -1 = 255!
308307
if len(this_rising) > 0:
309308
timestamps.extend(this_rising)
310309
labels.extend([channel + ' ON']*len(this_rising))

neo/test/rawiotest/test_spikeglxrawio.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from neo.rawio.spikeglxrawio import SpikeGLXRawIO
88
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
9-
import numpy
9+
import numpy as np
1010

1111
class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
1212
rawioclass = SpikeGLXRawIO
@@ -29,7 +29,6 @@ class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
2929
"spikeglx/NP2_with_sync",
3030
"spikeglx/NP2_no_sync",
3131
"spikeglx/NP2_subset_with_sync",
32-
"spikeglx/DigitalChannelTest_g0",
3332
]
3433

3534
def test_with_location(self):
@@ -87,19 +86,19 @@ def test_subset_with_sync(self):
8786
assert chunk.shape[1] == 120
8887

8988
def test_nidq_digital_channel(self):
90-
rawio_digital = SpikeGLXRawIO("spikeglx/DigitalChannelTest_g0")
89+
rawio_digital = SpikeGLXRawIO(self.get_local_path("spikeglx/DigitalChannelTest_g0"))
9190
rawio_digital.parse_header()
9291
# This data should have 8 event channels
9392
assert(np.shape(rawio_digital.header['event_channels'])[0] == 8)
9493

9594
# Channel 0 in this data will have sync pulses at 1 Hz, let's confirm that
96-
all_events = rawio_digital.get_event_timestamps(0,0,0)
95+
all_events = rawio_digital.get_event_timestamps(0, 0, 0)
9796
on_events = np.where(all_events[2] == 'XD0 ON')
98-
on_ts = this_events[0][on_events]
99-
on_diff = np.unique(np.diff(on_ts))
100-
for diff in this_on_diff:
101-
error = 0.0001*rawio_digital.get_signal_sampling_rate()
102-
assert abs(diff - rawio_digital.get_signal_sampling_rate()) < error
97+
on_ts = all_events[0][on_events]
98+
on_ts_scaled = rawio_digital.rescale_event_timestamp(on_ts)
99+
on_diff = np.diff(on_ts_scaled)
100+
atol = 0.001
101+
assert np.allclose(on_diff, 1, atol=atol)
103102

104103
if __name__ == "__main__":
105104
unittest.main()

0 commit comments

Comments
 (0)