Commit e14c956

Merge pull request #1547 from h-mayorquin/cimprove_plexon_stream_ids
Refactor plexon rawio to have same ids as plexon2
2 parents: 32ac313 + ecc3c1e

2 files changed: 72 additions & 58 deletions
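
For orientation, a minimal usage sketch of what this merge aligns (the file paths and printed ids are placeholders): after `parse_header()`, both readers now report signal stream ids derived from the channel-name prefixes.

```python
# Placeholder paths; the printed ids are illustrative.
from neo.rawio import PlexonRawIO, Plexon2RawIO

plx_reader = PlexonRawIO(filename="recording.plx")
plx_reader.parse_header()
print(plx_reader.header["signal_streams"]["id"])  # channel-prefix ids, e.g. "WB", "FP", ...

pl2_reader = Plexon2RawIO(filename="recording.pl2")
pl2_reader.parse_header()
print(pl2_reader.header["signal_streams"]["id"])  # matching prefix-based ids
```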

neo/rawio/plexon2rawio/plexon2rawio.py

Lines changed: 9 additions & 8 deletions
@@ -53,7 +53,7 @@ class Plexon2RawIO(BaseRawIO):
     pl2_dll_file_path: str | Path | None, default: None
         The path to the necessary dll for loading pl2 files
         If None will find correct dll for architecture and if it does not exist will download it
-    reading_attempts: int, default: 15
+    reading_attempts: int, default: 25
         Number of attempts to read the file before raising an error
         This opening process is somewhat unreliable and might fail occasionally. Adjust this higher
         if you encounter problems in opening the file.
@@ -92,7 +92,7 @@ class Plexon2RawIO(BaseRawIO):
     extensions = ["pl2"]
     rawmode = "one-file"
 
-    def __init__(self, filename, pl2_dll_file_path=None, reading_attempts=15):
+    def __init__(self, filename, pl2_dll_file_path=None, reading_attempts=25):
 
         # signals, event and spiking data will be cached
         # cached signal data can be cleared using `clear_analogsignal_cache()()`
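
A hedged sketch of how the new default would be used or overridden (the file path is a placeholder):

```python
from neo.rawio import Plexon2RawIO

reader = Plexon2RawIO(
    filename="session.pl2",   # placeholder path
    pl2_dll_file_path=None,   # None: find the dll for this architecture, downloading it if missing
    reading_attempts=25,      # new default; raise it if opening the file still fails intermittently
)
reader.parse_header()
```
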
@@ -196,6 +196,7 @@ def _parse_header(self):
             "FP": "FPl-Low Pass Filtered",
             "SP": "SPKC-High Pass Filtered",
             "AI": "AI-Auxiliary Input",
+            "AIF": "AIF-Auxiliary Input Filtered",
         }
 
         unique_stream_ids = np.unique(signal_channels["stream_id"])
@@ -209,17 +210,17 @@ def _parse_header(self):
 
         signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
 
-        self.stream_id_samples = {}
-        self.stream_index_to_stream_id = {}
+        self._stream_id_samples = {}
+        self._stream_index_to_stream_id = {}
         for stream_index, stream_id in enumerate(signal_streams["id"]):
             # Keep a mapping from stream_index to stream_id
-            self.stream_index_to_stream_id[stream_index] = stream_id
+            self._stream_index_to_stream_id[stream_index] = stream_id
 
             # We extract the number of samples for each stream
             mask = signal_channels["stream_id"] == stream_id
             signal_num_samples = np.unique(channel_num_samples[mask])
             assert signal_num_samples.size == 1, "All channels in a stream must have the same number of samples"
-            self.stream_id_samples[stream_id] = signal_num_samples[0]
+            self._stream_id_samples[stream_id] = signal_num_samples[0]
 
         # pre-loading spike channel_data for later usage
         self._spike_channel_cache = {}
@@ -386,8 +387,8 @@ def _segment_t_stop(self, block_index, seg_index):
         return float(end_time / self.pl2reader.pl2_file_info.m_TimestampFrequency)
 
     def _get_signal_size(self, block_index, seg_index, stream_index):
-        stream_id = self.stream_index_to_stream_id[stream_index]
-        num_samples = int(self.stream_id_samples[stream_id])
+        stream_id = self._stream_index_to_stream_id[stream_index]
+        num_samples = int(self._stream_id_samples[stream_id])
         return num_samples
 
     def _get_signal_t_start(self, block_index, seg_index, stream_index):

neo/rawio/plexonrawio.py

Lines changed: 63 additions & 50 deletions
@@ -43,6 +43,7 @@
     _event_channel_dtype,
 )
 
+from neo.core.baseneo import NeoReadWriteError
 
 class PlexonRawIO(BaseRawIO):
     extensions = ["plx"]
@@ -230,9 +231,19 @@ def _parse_header(self):
             self._data_blocks[bl_type][chan_id] = data_block
 
         # signals channels
-        sig_channels = []
-        all_sig_length = []
         source_id = []
+
+        # Scanning sources and populating signal channels at the same time. Sources have to have
+        # same sampling rate and number of samples to belong to one stream.
+        signal_channels = []
+        channel_num_samples = []
+
+        # We will build the stream ids based on the channel prefixes
+        # The channel prefixes are the first characters of the channel names which have the following format:
+        # WB{number}, FPX{number}, SPKCX{number}, AI{number}, etc
+        # We will extract the prefix and use it as stream id
+        regex_prefix_pattern = r"^\D+"  # Match any non-digit character at the beginning of the string
+
         if self.progress_bar:
             chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True)
         else:
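
A standalone sketch of the prefix extraction set up above; the channel names are made-up examples of the formats listed in the comment:

```python
import re

regex_prefix_pattern = r"^\D+"  # the leading non-digit characters become the stream id

for name in ["WB01", "FP12", "SPKC03", "AI09"]:  # illustrative channel names
    stream_id = re.match(regex_prefix_pattern, name).group(0)
    print(f"{name} -> {stream_id}")
# WB01 -> WB, FP12 -> FP, SPKC03 -> SPKC, AI09 -> AI
```
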
@@ -245,7 +256,7 @@ def _parse_header(self):
             if length == 0:
                 continue  # channel not added
             source_id.append(h["SrcId"])
-            all_sig_length.append(length)
+            channel_num_samples.append(length)
             sampling_rate = float(h["ADFreq"])
             sig_dtype = "int16"
             units = ""  # I don't know units
@@ -258,61 +269,60 @@ def _parse_header(self):
                 0.5 * (2 ** global_header["BitsPerSpikeSample"]) * h["Gain"] * h["PreampGain"]
             )
             offset = 0.0
-            stream_id = "0"  # This is overwritten later
-            sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id))
+            channel_prefix = re.match(regex_prefix_pattern, name).group(0)
+            stream_id = channel_prefix
+
+            signal_channels.append((name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id))
 
-        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
+        signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
 
-        if sig_channels.size == 0:
+        if signal_channels.size == 0:
             signal_streams = np.array([], dtype=_signal_stream_dtype)
 
         else:
             # Detect streams
-            all_sig_length = np.asarray(all_sig_length)
-
-            # names are WB{number}, FPX{number}, SPKCX{number}, AI{number}, etc
-            pattern = r"^\D+"  # Match any non-digit character at the beginning of the string
-            channels_prefixes = np.asarray([re.match(pattern, name).group(0) for name in sig_channels["name"]])
-            buffer_stream_groups = set(zip(channels_prefixes, sig_channels["sampling_rate"], all_sig_length))
-
-            # There are explanations of the streams based on channel names
-            # provided by a Plexon Engineer, see here:
+            channel_num_samples = np.asarray(channel_num_samples)
+            # We are using channel prefixes as stream_ids
+            # The meaning of the channel prefixes was provided by a Plexon Engineer, see here:
             # https://github.com/NeuralEnsemble/python-neo/pull/1495#issuecomment-2184256894
-            channel_prefix_to_stream_name = {
+            stream_id_to_stream_name = {
                 "WB": "WB-Wideband",
-                "FP": "FPl-Low Pass Filtered ",
+                "FP": "FPl-Low Pass Filtered",
                 "SP": "SPKC-High Pass Filtered",
                 "AI": "AI-Auxiliary Input",
+                "AIF": "AIF-Auxiliary Input Filtered",
             }
 
-            # Using a mapping to ensure consistent order of stream_index
-            channel_prefix_to_stream_id = {
-                "WB": "0",
-                "FP": "1",
-                "SP": "2",
-                "AI": "3",
-            }
-
+            unique_stream_ids = np.unique(signal_channels["stream_id"])
             signal_streams = []
-            self._signal_length = {}
-            self._sig_sampling_rate = {}
-
-            for stream_index, (channel_prefix, sr, length) in enumerate(buffer_stream_groups):
-                # The users of plexon can modify the prefix of the channel names (e.g. `my_prefix` instead of `WB`). This is not common but in that case
-                # We assign the channel_prefix both as stream_name and stream_id
-                stream_name = channel_prefix_to_stream_name.get(channel_prefix, channel_prefix)
-                stream_id = channel_prefix_to_stream_id.get(channel_prefix, channel_prefix)
-
-                mask = (sig_channels["sampling_rate"] == sr) & (all_sig_length == length)
-                sig_channels["stream_id"][mask] = stream_id
-
-                self._sig_sampling_rate[stream_index] = sr
-                self._signal_length[stream_index] = length
-
+            for stream_id in unique_stream_ids:
+                # We are using the channel prefixes as ids
+                # The users of plexon can modify the prefix of the channel names (e.g. `my_prefix` instead of `WB`).
+                # In that case we use the channel prefix both as stream id and name
+                stream_name = stream_id_to_stream_name.get(stream_id, stream_id)
                 signal_streams.append((stream_name, stream_id))
 
             signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
 
+            self._stream_id_samples = {}
+            self._stream_id_sampling_frequency = {}
+            self._stream_index_to_stream_id = {}
+            for stream_index, stream_id in enumerate(signal_streams["id"]):
+                # Keep a mapping from stream_index to stream_id
+                self._stream_index_to_stream_id[stream_index] = stream_id
+
+                mask = signal_channels["stream_id"] == stream_id
+
+                signal_num_samples = np.unique(channel_num_samples[mask])
+                if signal_num_samples.size > 1:
+                    raise NeoReadWriteError(f"Channels in stream {stream_id} don't have the same number of samples")
+                self._stream_id_samples[stream_id] = signal_num_samples[0]
+
+                signal_sampling_frequency = np.unique(signal_channels[mask]["sampling_rate"])
+                if signal_sampling_frequency.size > 1:
+                    raise NeoReadWriteError(f"Channels in stream {stream_id} don't have the same sampling frequency")
+                self._stream_id_sampling_frequency[stream_id] = signal_sampling_frequency[0]
+
         self._global_ssampling_rate = global_header["ADFrequency"]
 
         # Determine number of units per channels
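
A toy, self-contained version of the per-stream consistency checks added above, using plain arrays and `ValueError` in place of the reader's internal state and `NeoReadWriteError`:

```python
import numpy as np

# Toy data: one stream id and one sample count per channel (not read from a real .plx file).
stream_ids_per_channel = np.array(["WB", "WB", "FP"])
channel_num_samples = np.array([1000, 1000, 500])

stream_id_samples = {}
for stream_id in np.unique(stream_ids_per_channel):
    mask = stream_ids_per_channel == stream_id
    signal_num_samples = np.unique(channel_num_samples[mask])
    if signal_num_samples.size > 1:
        raise ValueError(f"Channels in stream {stream_id} don't have the same number of samples")
    stream_id_samples[stream_id] = int(signal_num_samples[0])

print(stream_id_samples)  # {'FP': 500, 'WB': 1000}
```
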
@@ -374,7 +384,7 @@ def _parse_header(self):
             "nb_block": 1,
             "nb_segment": [1],
             "signal_streams": signal_streams,
-            "signal_channels": sig_channels,
+            "signal_channels": signal_channels,
             "spike_channels": spike_channels,
             "event_channels": event_channels,
         }
@@ -392,28 +402,31 @@ def _segment_t_start(self, block_index, seg_index):
 
     def _segment_t_stop(self, block_index, seg_index):
         t_stop = float(self._last_timestamps) / self._global_ssampling_rate
-        if hasattr(self, "_signal_length"):
-            for stream_index in self._signal_length.keys():
-                t_stop_sig = self._signal_length[stream_index] / self._sig_sampling_rate[stream_index]
+        if hasattr(self, "__stream_id_samples"):
+            for stream_id in self._stream_id_samples.keys():
+                t_stop_sig = self._stream_id_samples[stream_id] / self._stream_id_sampling_frequency[stream_id]
                 t_stop = max(t_stop, t_stop_sig)
         return t_stop
 
     def _get_signal_size(self, block_index, seg_index, stream_index):
-        return self._signal_length[stream_index]
+        stream_id = self._stream_index_to_stream_id[stream_index]
+        return self._stream_id_samples[stream_id]
 
     def _get_signal_t_start(self, block_index, seg_index, stream_index):
         return 0.0
 
     def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
+        signal_channels = self.header["signal_channels"]
+        signal_streams = self.header["signal_streams"]
+        stream_id = signal_streams[stream_index]["id"]
+
         if i_start is None:
             i_start = 0
         if i_stop is None:
-            i_stop = self._signal_length[stream_index]
+            i_stop = self._stream_id_samples[stream_id]
+
 
-        signal_channels = self.header["signal_channels"]
-        signal_streams = self.header["signal_streams"]
 
-        stream_id = signal_streams[stream_index]["id"]
         mask = signal_channels["stream_id"] == stream_id
         signal_channels = signal_channels[mask]
         if channel_indexes is not None:
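
With these changes, the public BaseRawIO accessors resolve sizes and chunks through the stream id internally; a short usage sketch (placeholder path, illustrative indices):

```python
from neo.rawio import PlexonRawIO

reader = PlexonRawIO(filename="recording.plx")  # placeholder path
reader.parse_header()

stream_index = 0  # first stream, whatever its prefix-based id is
size = reader.get_signal_size(block_index=0, seg_index=0, stream_index=stream_index)
chunk = reader.get_analogsignal_chunk(
    block_index=0,
    seg_index=0,
    i_start=0,
    i_stop=min(size, 1000),
    stream_index=stream_index,
    channel_indexes=None,  # all channels in this stream
)
print(size, chunk.shape)
```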
