Skip to content

Commit 039144d

Browse files
authored
Merge branch 'master' into neuronexus
2 parents ccc01f9 + 3e1716f commit 039144d

File tree

5 files changed

+126
-67
lines changed

5 files changed

+126
-67
lines changed

neo/rawio/plexon2rawio/plexon2rawio.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class Plexon2RawIO(BaseRawIO):
5353
pl2_dll_file_path: str | Path | None, default: None
5454
The path to the necessary dll for loading pl2 files
5555
If None will find correct dll for architecture and if it does not exist will download it
56-
reading_attempts: int, default: 15
56+
reading_attempts: int, default: 25
5757
Number of attempts to read the file before raising an error
5858
This opening process is somewhat unreliable and might fail occasionally. Adjust this higher
5959
if you encounter problems in opening the file.
@@ -92,7 +92,7 @@ class Plexon2RawIO(BaseRawIO):
9292
extensions = ["pl2"]
9393
rawmode = "one-file"
9494

95-
def __init__(self, filename, pl2_dll_file_path=None, reading_attempts=15):
95+
def __init__(self, filename, pl2_dll_file_path=None, reading_attempts=25):
9696

9797
# signals, event and spiking data will be cached
9898
# cached signal data can be cleared using `clear_analogsignal_cache()()`
@@ -196,6 +196,7 @@ def _parse_header(self):
196196
"FP": "FPl-Low Pass Filtered",
197197
"SP": "SPKC-High Pass Filtered",
198198
"AI": "AI-Auxiliary Input",
199+
"AIF": "AIF-Auxiliary Input Filtered",
199200
}
200201

201202
unique_stream_ids = np.unique(signal_channels["stream_id"])
@@ -209,17 +210,17 @@ def _parse_header(self):
209210

210211
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
211212

212-
self.stream_id_samples = {}
213-
self.stream_index_to_stream_id = {}
213+
self._stream_id_samples = {}
214+
self._stream_index_to_stream_id = {}
214215
for stream_index, stream_id in enumerate(signal_streams["id"]):
215216
# Keep a mapping from stream_index to stream_id
216-
self.stream_index_to_stream_id[stream_index] = stream_id
217+
self._stream_index_to_stream_id[stream_index] = stream_id
217218

218219
# We extract the number of samples for each stream
219220
mask = signal_channels["stream_id"] == stream_id
220221
signal_num_samples = np.unique(channel_num_samples[mask])
221222
assert signal_num_samples.size == 1, "All channels in a stream must have the same number of samples"
222-
self.stream_id_samples[stream_id] = signal_num_samples[0]
223+
self._stream_id_samples[stream_id] = signal_num_samples[0]
223224

224225
# pre-loading spike channel_data for later usage
225226
self._spike_channel_cache = {}
@@ -231,7 +232,14 @@ def _parse_header(self):
231232
if not (schannel_info.m_ChannelEnabled and schannel_info.m_ChannelRecordingEnabled):
232233
continue
233234

234-
for channel_unit_id in range(schannel_info.m_NumberOfUnits):
235+
# In a PL2 spike channel header, the field "m_NumberOfUnits" denotes the number
236+
# of units to which spikes detected on that channel have been assigned. It does
237+
# not account for unsorted spikes, i.e., spikes that have not been assigned to
238+
# a unit. It is Plexon's convention to assign unsorted spikes to channel_unit_id=0,
239+
# and sorted spikes to channel_unit_id's 1, 2, 3...etc. Therefore, for a given value of
240+
# m_NumberOfUnits, there are m_NumberOfUnits+1 channel_unit_ids to consider - 1
241+
# unsorted channel_unit_id (0) + the m_NumberOfUnits sorted channel_unit_ids.
242+
for channel_unit_id in range(schannel_info.m_NumberOfUnits+1):
235243
unit_name = f"{schannel_info.m_Name.decode()}.{channel_unit_id}"
236244
unit_id = f"unit{schannel_info.m_Channel}.{channel_unit_id}"
237245
wf_units = schannel_info.m_Units
@@ -386,8 +394,8 @@ def _segment_t_stop(self, block_index, seg_index):
386394
return float(end_time / self.pl2reader.pl2_file_info.m_TimestampFrequency)
387395

388396
def _get_signal_size(self, block_index, seg_index, stream_index):
389-
stream_id = self.stream_index_to_stream_id[stream_index]
390-
num_samples = int(self.stream_id_samples[stream_id])
397+
stream_id = self._stream_index_to_stream_id[stream_index]
398+
num_samples = int(self._stream_id_samples[stream_id])
391399
return num_samples
392400

393401
def _get_signal_t_start(self, block_index, seg_index, stream_index):

neo/rawio/plexonrawio.py

Lines changed: 63 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
_event_channel_dtype,
4444
)
4545

46+
from neo.core.baseneo import NeoReadWriteError
4647

4748
class PlexonRawIO(BaseRawIO):
4849
extensions = ["plx"]
@@ -230,9 +231,19 @@ def _parse_header(self):
230231
self._data_blocks[bl_type][chan_id] = data_block
231232

232233
# signals channels
233-
sig_channels = []
234-
all_sig_length = []
235234
source_id = []
235+
236+
# Scanning sources and populating signal channels at the same time. Sources have to have
237+
# same sampling rate and number of samples to belong to one stream.
238+
signal_channels = []
239+
channel_num_samples = []
240+
241+
# We will build the stream ids based on the channel prefixes
242+
# The channel prefixes are the first characters of the channel names which have the following format:
243+
# WB{number}, FPX{number}, SPKCX{number}, AI{number}, etc
244+
# We will extract the prefix and use it as stream id
245+
regex_prefix_pattern = r"^\D+" # Match any non-digit character at the beginning of the string
246+
236247
if self.progress_bar:
237248
chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True)
238249
else:
@@ -245,7 +256,7 @@ def _parse_header(self):
245256
if length == 0:
246257
continue # channel not added
247258
source_id.append(h["SrcId"])
248-
all_sig_length.append(length)
259+
channel_num_samples.append(length)
249260
sampling_rate = float(h["ADFreq"])
250261
sig_dtype = "int16"
251262
units = "" # I don't know units
@@ -258,61 +269,60 @@ def _parse_header(self):
258269
0.5 * (2 ** global_header["BitsPerSpikeSample"]) * h["Gain"] * h["PreampGain"]
259270
)
260271
offset = 0.0
261-
stream_id = "0" # This is overwritten later
262-
sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id))
272+
channel_prefix = re.match(regex_prefix_pattern, name).group(0)
273+
stream_id = channel_prefix
274+
275+
signal_channels.append((name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id))
263276

264-
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
277+
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
265278

266-
if sig_channels.size == 0:
279+
if signal_channels.size == 0:
267280
signal_streams = np.array([], dtype=_signal_stream_dtype)
268281

269282
else:
270283
# Detect streams
271-
all_sig_length = np.asarray(all_sig_length)
272-
273-
# names are WB{number}, FPX{number}, SPKCX{number}, AI{number}, etc
274-
pattern = r"^\D+" # Match any non-digit character at the beginning of the string
275-
channels_prefixes = np.asarray([re.match(pattern, name).group(0) for name in sig_channels["name"]])
276-
buffer_stream_groups = set(zip(channels_prefixes, sig_channels["sampling_rate"], all_sig_length))
277-
278-
# There are explanations of the streams based on channel names
279-
# provided by a Plexon Engineer, see here:
284+
channel_num_samples = np.asarray(channel_num_samples)
285+
# We are using channel prefixes as stream_ids
286+
# The meaning of the channel prefixes was provided by a Plexon Engineer, see here:
280287
# https://github.com/NeuralEnsemble/python-neo/pull/1495#issuecomment-2184256894
281-
channel_prefix_to_stream_name = {
288+
stream_id_to_stream_name = {
282289
"WB": "WB-Wideband",
283-
"FP": "FPl-Low Pass Filtered ",
290+
"FP": "FPl-Low Pass Filtered",
284291
"SP": "SPKC-High Pass Filtered",
285292
"AI": "AI-Auxiliary Input",
293+
"AIF": "AIF-Auxiliary Input Filtered",
286294
}
287295

288-
# Using a mapping to ensure consistent order of stream_index
289-
channel_prefix_to_stream_id = {
290-
"WB": "0",
291-
"FP": "1",
292-
"SP": "2",
293-
"AI": "3",
294-
}
295-
296+
unique_stream_ids = np.unique(signal_channels["stream_id"])
296297
signal_streams = []
297-
self._signal_length = {}
298-
self._sig_sampling_rate = {}
299-
300-
for stream_index, (channel_prefix, sr, length) in enumerate(buffer_stream_groups):
301-
# The users of plexon can modify the prefix of the channel names (e.g. `my_prefix` instead of `WB`). This is not common but in that case
302-
# We assign the channel_prefix both as stream_name and stream_id
303-
stream_name = channel_prefix_to_stream_name.get(channel_prefix, channel_prefix)
304-
stream_id = channel_prefix_to_stream_id.get(channel_prefix, channel_prefix)
305-
306-
mask = (sig_channels["sampling_rate"] == sr) & (all_sig_length == length)
307-
sig_channels["stream_id"][mask] = stream_id
308-
309-
self._sig_sampling_rate[stream_index] = sr
310-
self._signal_length[stream_index] = length
311-
298+
for stream_id in unique_stream_ids:
299+
# We are using the channel prefixes as ids
300+
# The users of plexon can modify the prefix of the channel names (e.g. `my_prefix` instead of `WB`).
301+
# In that case we use the channel prefix both as stream id and name
302+
stream_name = stream_id_to_stream_name.get(stream_id, stream_id)
312303
signal_streams.append((stream_name, stream_id))
313304

314305
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
315306

307+
self._stream_id_samples = {}
308+
self._stream_id_sampling_frequency = {}
309+
self._stream_index_to_stream_id = {}
310+
for stream_index, stream_id in enumerate(signal_streams["id"]):
311+
# Keep a mapping from stream_index to stream_id
312+
self._stream_index_to_stream_id[stream_index] = stream_id
313+
314+
mask = signal_channels["stream_id"] == stream_id
315+
316+
signal_num_samples = np.unique(channel_num_samples[mask])
317+
if signal_num_samples.size > 1:
318+
raise NeoReadWriteError(f"Channels in stream {stream_id} don't have the same number of samples")
319+
self._stream_id_samples[stream_id] = signal_num_samples[0]
320+
321+
signal_sampling_frequency = np.unique(signal_channels[mask]["sampling_rate"])
322+
if signal_sampling_frequency.size > 1:
323+
raise NeoReadWriteError(f"Channels in stream {stream_id} don't have the same sampling frequency")
324+
self._stream_id_sampling_frequency[stream_id] = signal_sampling_frequency[0]
325+
316326
self._global_ssampling_rate = global_header["ADFrequency"]
317327

318328
# Determine number of units per channels
@@ -374,7 +384,7 @@ def _parse_header(self):
374384
"nb_block": 1,
375385
"nb_segment": [1],
376386
"signal_streams": signal_streams,
377-
"signal_channels": sig_channels,
387+
"signal_channels": signal_channels,
378388
"spike_channels": spike_channels,
379389
"event_channels": event_channels,
380390
}
@@ -392,28 +402,31 @@ def _segment_t_start(self, block_index, seg_index):
392402

393403
def _segment_t_stop(self, block_index, seg_index):
394404
t_stop = float(self._last_timestamps) / self._global_ssampling_rate
395-
if hasattr(self, "_signal_length"):
396-
for stream_index in self._signal_length.keys():
397-
t_stop_sig = self._signal_length[stream_index] / self._sig_sampling_rate[stream_index]
405+
if hasattr(self, "_stream_id_samples"):
406+
for stream_id in self._stream_id_samples.keys():
407+
t_stop_sig = self._stream_id_samples[stream_id] / self._stream_id_sampling_frequency[stream_id]
398408
t_stop = max(t_stop, t_stop_sig)
399409
return t_stop
400410

401411
def _get_signal_size(self, block_index, seg_index, stream_index):
402-
return self._signal_length[stream_index]
412+
stream_id = self._stream_index_to_stream_id[stream_index]
413+
return self._stream_id_samples[stream_id]
403414

404415
def _get_signal_t_start(self, block_index, seg_index, stream_index):
405416
return 0.0
406417

407418
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
419+
signal_channels = self.header["signal_channels"]
420+
signal_streams = self.header["signal_streams"]
421+
stream_id = signal_streams[stream_index]["id"]
422+
408423
if i_start is None:
409424
i_start = 0
410425
if i_stop is None:
411-
i_stop = self._signal_length[stream_index]
426+
i_stop = self._stream_id_samples[stream_id]
427+
412428

413-
signal_channels = self.header["signal_channels"]
414-
signal_streams = self.header["signal_streams"]
415429

416-
stream_id = signal_streams[stream_index]["id"]
417430
mask = signal_channels["stream_id"] == stream_id
418431
signal_channels = signal_channels[mask]
419432
if channel_indexes is not None:

neo/test/rawiotest/common_rawio_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,15 @@ class BaseTestRawIO:
5959

6060
local_test_dir = get_local_testing_data_folder()
6161

62-
def setUp(self):
62+
@classmethod
63+
def setUpClass(cls):
6364
"""
64-
Set up the test fixture. This is run for every test
65+
Set up the test fixture. This is run once before all tests.
6566
"""
66-
self.shortname = self.rawioclass.__name__.lower().replace("rawio", "")
67+
cls.shortname = cls.rawioclass.__name__.lower().replace("rawio", "")
6768

68-
if HAVE_DATALAD and self.use_network:
69-
for remote_path in self.entities_to_download:
69+
if HAVE_DATALAD and cls.use_network:
70+
for remote_path in cls.entities_to_download:
7071
download_dataset(repo=repo_for_test, remote_path=remote_path)
7172
else:
7273
raise unittest.SkipTest("Requires datalad download of data from the web")

neo/test/rawiotest/test_maxwellrawio.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
import os
32

43
from neo.rawio.maxwellrawio import MaxwellRawIO, auto_install_maxwell_hdf5_compression_plugin
54
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
@@ -9,6 +8,7 @@ class TestMaxwellRawIO(
98
BaseTestRawIO,
109
unittest.TestCase,
1110
):
11+
1212
rawioclass = MaxwellRawIO
1313
entities_to_download = ["maxwell"]
1414
entities_to_test = files_to_test = [
@@ -18,7 +18,6 @@ class TestMaxwellRawIO(
1818

1919
def setUp(self):
2020
auto_install_maxwell_hdf5_compression_plugin(force_download=False)
21-
BaseTestRawIO.setUp(self)
2221

2322

2423
if __name__ == "__main__":

neo/test/rawiotest/test_plexon2rawio.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
1111

12+
from numpy.testing import assert_equal
1213

1314
try:
1415
from neo.rawio.plexon2rawio.pypl2 import pypl2lib
@@ -25,8 +26,45 @@ class TestPlexon2RawIO(
2526
):
2627
rawioclass = Plexon2RawIO
2728
entities_to_download = ["plexon"]
28-
entities_to_test = ["plexon/4chDemoPL2.pl2"]
29+
entities_to_test = ["plexon/4chDemoPL2.pl2",
30+
"plexon/NC16FPSPKEVT_1m.pl2"
31+
]
32+
2933

34+
def test_check_enabled_flags(self):
35+
"""
36+
This test loads a 1-minute PL2 file with 16 channels' each
37+
of field potential (FP), and spike (SPK) data. The channels
38+
cycle through 4 possible combinations of m_ChannelEnabled
39+
and m_ChannelRecordingEnabled - (True, True), (True, False),
40+
(False, True), and (False, False). With 16 channels for each
41+
source, each combination of flags occurs 4 times. Only the
42+
first combination (True, True) causes data to be recorded to
43+
disk. Therefore, we expect the following channels to be loaded by
44+
Neo: FP01, FP05, FP09, FP13, SPK01, SPK05, SPK09, and SPK13.
45+
46+
Note: the file contains event (EVT) data as well. Although event
47+
channel headers do contain m_ChannelEnabled and m_ChannelRecording-
48+
Enabled flags, the UI for recording PL2 files does not expose any
49+
controls by which these flags can be changed from (True, True).
50+
Therefore, no test for event channels is necessary here.
51+
"""
52+
53+
# Load data from NC16FPSPKEVT_1m.pl2, a 1-minute PL2 recording containing
54+
# 16-channels' each of field potential (FP), spike (SPK), and event (EVT)
55+
# data.
56+
reader = Plexon2RawIO(filename=self.get_local_path("plexon/NC16FPSPKEVT_1m.pl2"))
57+
reader.parse_header()
58+
59+
# Check that the names of the loaded signal channels match what we expect
60+
signal_channel_names = reader.header["signal_channels"]["name"].tolist()
61+
expected_signal_channel_names = ["FP01","FP05","FP09","FP13"]
62+
assert_equal(signal_channel_names, expected_signal_channel_names)
63+
64+
# Check that the names of the loaded spike channels match what we expect
65+
spike_channel_names = reader.header["spike_channels"]["name"].tolist()
66+
expected_spike_channel_names = ["SPK01.0", "SPK05.0", "SPK09.0", "SPK13.0"]
67+
assert_equal(spike_channel_names, expected_spike_channel_names)
3068

3169
if __name__ == "__main__":
3270
unittest.main()

0 commit comments

Comments
 (0)