Skip to content

Commit 36474cb

Browse files
committed
black
1 parent 20a753f commit 36474cb

File tree

2 files changed

+52
-48
lines changed

2 files changed

+52
-48
lines changed

neo/rawio/spikeglxrawio.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -180,20 +180,20 @@ def _parse_header(self):
180180
event_channels = []
181181
# This is true only in case of 'nidq' stream
182182
for stream_name in stream_names:
183-
if 'nidq' in stream_name:
184-
info = self.signals_info_dict[0, stream_name]
185-
if len(info['digital_channels']) > 0:
183+
if "nidq" in stream_name:
184+
info = self.signals_info_dict[0, stream_name]
185+
if len(info["digital_channels"]) > 0:
186186
# add event channels
187-
for local_chan in info['digital_channels']:
187+
for local_chan in info["digital_channels"]:
188188
chan_name = local_chan
189-
chan_id = f'{stream_name}#{chan_name}'
190-
event_channels.append((chan_name, chan_id, 'event'))
191-
# add events_memmap
192-
data = np.memmap(info['bin_file'], dtype='int16', mode='r', offset=0, order='C')
193-
data = data.reshape(-1, info['num_chan'])
189+
chan_id = f"{stream_name}#{chan_name}"
190+
event_channels.append((chan_name, chan_id, "event"))
191+
# add events_memmap
192+
data = np.memmap(info["bin_file"], dtype="int16", mode="r", offset=0, order="C")
193+
data = data.reshape(-1, info["num_chan"])
194194
# The digital word is usually the last channel, after all the individual analog channels
195-
extracted_word = data[:,len(info['analog_channels'])]
196-
self._events_memmap = np.unpackbits(extracted_word.astype(np.uint8)[:,None], axis=1)
195+
extracted_word = data[:, len(info["analog_channels"])]
196+
self._events_memmap = np.unpackbits(extracted_word.astype(np.uint8)[:, None], axis=1)
197197
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
198198

199199
# No spikes
@@ -288,35 +288,37 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
288288
raw_signals = memmap[slice(i_start, i_stop), channel_selection]
289289

290290
return raw_signals
291-
291+
292292
def _event_count(self, event_channel_idx, block_index=None, seg_index=None):
293293
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None)
294294
return timestamps.size
295-
295+
296296
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start=None, t_stop=None):
297297
timestamps, durations, labels = [], None, []
298-
info = self.signals_info_dict[0, 'nidq'] # There are no events that are not in the nidq stream
299-
dig_ch = info['digital_channels']
298+
info = self.signals_info_dict[0, "nidq"] # There are no events that are not in the nidq stream
299+
dig_ch = info["digital_channels"]
300300
if len(dig_ch) > 0:
301301
event_data = self._events_memmap
302302
channel = dig_ch[event_channel_index]
303-
ch_idx = 7 - int(channel[2:]) # They are in the reverse order
304-
this_stream = event_data[:,ch_idx]
305-
this_rising = np.where(np.diff(this_stream)==1)[0] + 1
306-
this_falling = np.where(np.diff(this_stream)==255)[0] + 1 # because the data is in unsigned 8 bit, -1 = 255!
303+
ch_idx = 7 - int(channel[2:]) # They are in the reverse order
304+
this_stream = event_data[:, ch_idx]
305+
this_rising = np.where(np.diff(this_stream) == 1)[0] + 1
306+
this_falling = (
307+
np.where(np.diff(this_stream) == 255)[0] + 1
308+
) # because the data is in unsigned 8 bit, -1 = 255!
307309
if len(this_rising) > 0:
308310
timestamps.extend(this_rising)
309-
labels.extend([channel + ' ON']*len(this_rising))
311+
labels.extend([channel + " ON"] * len(this_rising))
310312
if len(this_falling) > 0:
311313
timestamps.extend(this_falling)
312-
labels.extend([channel + ' OFF']*len(this_falling))
313-
return np.asarray(timestamps), np.asarray(durations), np.asarray(labels)
314+
labels.extend([channel + " OFF"] * len(this_falling))
315+
return np.asarray(timestamps), np.asarray(durations), np.asarray(labels)
314316

315317
def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
316-
info = self.signals_info_dict[0, 'nidq'] # There are no events that are not in the nidq stream
318+
info = self.signals_info_dict[0, "nidq"] # There are no events that are not in the nidq stream
317319
if not self._use_direct_evt_timestamps:
318-
event_times = event_timestamps.astype(dtype) / float(info['sampling_rate'])
319-
else: # Does this ever happen?
320+
event_times = event_timestamps.astype(dtype) / float(info["sampling_rate"])
321+
else: # Does this ever happen?
320322
event_times = event_timestamps.astype(dtype)
321323
return event_times
322324

@@ -543,26 +545,26 @@ def extract_stream_info(meta_file, meta):
543545
info["sampling_rate"] = float(meta[k])
544546
info["num_chan"] = num_chan
545547

546-
info['sample_length'] = int(meta['fileSizeBytes']) // 2 // num_chan
547-
info['gate_num'] = gate_num
548-
info['trigger_num'] = trigger_num
549-
info['device'] = device
550-
info['stream_kind'] = stream_kind
551-
info['stream_name'] = stream_name
552-
info['units'] = units
553-
info['channel_names'] = [txt.split(';')[0] for txt in meta['snsChanMap']]
554-
info['channel_gains'] = channel_gains
555-
info['channel_offsets'] = np.zeros(info['num_chan'])
556-
info['has_sync_trace'] = has_sync_trace
557-
558-
if 'nidq' in device:
559-
info['digital_channels'] = []
560-
info['analog_channels'] = [channel for channel in info['channel_names'] if not channel.startswith('XD')]
548+
info["sample_length"] = int(meta["fileSizeBytes"]) // 2 // num_chan
549+
info["gate_num"] = gate_num
550+
info["trigger_num"] = trigger_num
551+
info["device"] = device
552+
info["stream_kind"] = stream_kind
553+
info["stream_name"] = stream_name
554+
info["units"] = units
555+
info["channel_names"] = [txt.split(";")[0] for txt in meta["snsChanMap"]]
556+
info["channel_gains"] = channel_gains
557+
info["channel_offsets"] = np.zeros(info["num_chan"])
558+
info["has_sync_trace"] = has_sync_trace
559+
560+
if "nidq" in device:
561+
info["digital_channels"] = []
562+
info["analog_channels"] = [channel for channel in info["channel_names"] if not channel.startswith("XD")]
561563
# Digital/event channels are encoded within the digital word, so that will need more handling
562-
for item in meta['niXDChans1'].split(','):
563-
if ':' in item:
564-
start, end = map(int, item.split(':'))
565-
info['digital_channels'].extend([f"XD{i}" for i in range(start, end+1)])
564+
for item in meta["niXDChans1"].split(","):
565+
if ":" in item:
566+
start, end = map(int, item.split(":"))
567+
info["digital_channels"].extend([f"XD{i}" for i in range(start, end + 1)])
566568
else:
567-
info['digital_channels'].append(f"XD{int(item)}")
569+
info["digital_channels"].append(f"XD{int(item)}")
568570
return info

neo/test/rawiotest/test_spikeglxrawio.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
99
import numpy as np
1010

11+
1112
class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
1213
rawioclass = SpikeGLXRawIO
1314
entities_to_download = ["spikeglx"]
@@ -89,16 +90,17 @@ def test_nidq_digital_channel(self):
8990
rawio_digital = SpikeGLXRawIO(self.get_local_path("spikeglx/DigitalChannelTest_g0"))
9091
rawio_digital.parse_header()
9192
# This data should have 8 event channels
92-
assert(np.shape(rawio_digital.header['event_channels'])[0] == 8)
93+
assert np.shape(rawio_digital.header["event_channels"])[0] == 8
9394

9495
# Channel 0 in this data will have sync pulses at 1 Hz, let's confirm that
9596
all_events = rawio_digital.get_event_timestamps(0, 0, 0)
96-
on_events = np.where(all_events[2] == 'XD0 ON')
97+
on_events = np.where(all_events[2] == "XD0 ON")
9798
on_ts = all_events[0][on_events]
9899
on_ts_scaled = rawio_digital.rescale_event_timestamp(on_ts)
99100
on_diff = np.diff(on_ts_scaled)
100101
atol = 0.001
101-
assert np.allclose(on_diff, 1, atol=atol)
102+
assert np.allclose(on_diff, 1, atol=atol)
103+
102104

103105
if __name__ == "__main__":
104106
unittest.main()

0 commit comments

Comments
 (0)