Skip to content

Commit b409753

Browse files
committed
draft
1 parent 44d15cc commit b409753

File tree

3 files changed

+237
-19
lines changed

3 files changed

+237
-19
lines changed

neo/rawio/neuralynxrawio/neuralynxrawio.py

Lines changed: 128 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,20 @@ class NeuralynxRawIO(BaseRawIO):
134134
("samples", "int16", NcsSection._RECORD_SIZE),
135135
]
136136

137+
# Filter parameter keys used for stream differentiation
138+
_filter_keys = [
139+
"DSPLowCutFilterEnabled",
140+
"DspLowCutFrequency",
141+
"DspLowCutFilterType",
142+
"DspLowCutNumTaps",
143+
"DSPHighCutFilterEnabled",
144+
"DspHighCutFrequency",
145+
"DspHighCutFilterType",
146+
"DspHighCutNumTaps",
147+
"DspDelayCompensation",
148+
"DspFilterDelay_µs",
149+
]
150+
137151
def __init__(
138152
self,
139153
dirname="",
@@ -189,6 +203,61 @@ def __init__(
189203
def _source_name(self):
190204
return self.dirname
191205

206+
def _build_stream_key(self, header_info, chan_index, gain):
207+
"""
208+
Build stream key based on acquisition parameters only.
209+
210+
Stream keys are used to group channels that share the same acquisition
211+
configuration. Channels with the same stream key will be placed in the
212+
same stream and can be read together.
213+
214+
Parameters
215+
----------
216+
header_info : dict
217+
Header information from NlxHeader
218+
chan_index : int
219+
Channel index for multi-channel parameters
220+
gain : float
221+
Channel gain value (bit_to_microVolt)
222+
223+
Returns
224+
-------
225+
tuple
226+
Hashable stream key containing acquisition parameters:
227+
(sampling_rate, input_range, gain, input_inverted, filter_params_tuple)
228+
"""
229+
# Core acquisition parameters (already normalized by NlxHeader)
230+
sampling_rate = float(header_info["sampling_rate"])
231+
232+
# Get InputRange - could be int (single-channel) or list (multi-channel)
233+
input_range = header_info.get("InputRange")
234+
if isinstance(input_range, list):
235+
# Multi-channel file: get value for this channel
236+
input_range = input_range[chan_index] if chan_index < len(input_range) else input_range[0]
237+
# Already converted to int by NlxHeader._normalize_types()
238+
239+
gain = float(gain)
240+
241+
input_inverted = header_info.get("input_inverted", False)
242+
243+
# Filter parameters (already normalized by NlxHeader)
244+
filter_params = []
245+
for key in self._filter_keys:
246+
value = header_info.get(key)
247+
if value is not None:
248+
filter_params.append((key, value))
249+
250+
# Create hashable stream key
251+
stream_key = (
252+
sampling_rate,
253+
input_range,
254+
gain,
255+
input_inverted,
256+
tuple(sorted(filter_params)),
257+
)
258+
259+
return stream_key
260+
192261
def _parse_header(self):
193262

194263
stream_channels = []
@@ -268,26 +337,30 @@ def _parse_header(self):
268337

269338
chan_uid = (chan_name, str(chan_id))
270339
if ext == "ncs":
271-
file_mmap = self._get_file_map(filename)
272-
n_packets = copy.copy(file_mmap.shape[0])
273-
if n_packets:
274-
t_start = copy.copy(file_mmap[0][0])
275-
else: # empty file
276-
t_start = 0
277-
stream_prop = (float(info["sampling_rate"]), int(n_packets), float(t_start))
278-
if stream_prop not in stream_props:
279-
stream_props[stream_prop] = {"stream_id": len(stream_props), "filenames": [filename]}
340+
# Calculate gain for this channel
341+
gain = info["bit_to_microVolt"][idx]
342+
if info.get("input_inverted", False):
343+
gain *= -1
344+
345+
# Build stream key from acquisition parameters only
346+
stream_key = self._build_stream_key(info, idx, gain)
347+
348+
if stream_key not in stream_props:
349+
stream_props[stream_key] = {
350+
"stream_id": len(stream_props),
351+
"filenames": [filename],
352+
"channels": set(),
353+
}
280354
else:
281-
stream_props[stream_prop]["filenames"].append(filename)
282-
stream_id = stream_props[stream_prop]["stream_id"]
355+
stream_props[stream_key]["filenames"].append(filename)
356+
357+
stream_id = stream_props[stream_key]["stream_id"]
358+
stream_props[stream_key]["channels"].add((chan_name, str(chan_id)))
283359
# @zach @ramon : we need to discuss this split by channel buffer
284360
buffer_id = ""
285361

286362
# a sampled signal channel
287363
units = "uV"
288-
gain = info["bit_to_microVolt"][idx]
289-
if info.get("input_inverted", False):
290-
gain *= -1
291364
offset = 0.0
292365
signal_channels.append(
293366
(
@@ -392,14 +465,50 @@ def _parse_header(self):
392465
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
393466

394467
if signal_channels.size > 0:
395-
# ordering streams according from high to low sampling rates
396-
stream_props = {k: stream_props[k] for k in sorted(stream_props, reverse=True)}
397-
stream_names = [f"Stream (rate,#packet,t0): {sp}" for sp in stream_props]
398-
stream_ids = [stream_prop["stream_id"] for stream_prop in stream_props.values()]
399-
buffer_ids = ["" for sp in stream_props]
468+
# Build filter configuration registry
469+
filter_configs = {} # filter_params_tuple -> filter_id
470+
_filter_configurations = {} # filter_id -> filter parameters dict
471+
472+
for stream_key, stream_info in stream_props.items():
473+
# Extract filter parameters from stream_key
474+
# stream_key = (sampling_rate, input_range, gain, input_inverted, filter_params_tuple)
475+
sampling_rate, input_range, gain, input_inverted, filter_params_tuple = stream_key
476+
477+
# Assign filter ID (deduplicated by filter_params_tuple)
478+
if filter_params_tuple not in filter_configs:
479+
filter_id = len(filter_configs)
480+
filter_configs[filter_params_tuple] = filter_id
481+
# Convert filter_params_tuple to dict for storage
482+
_filter_configurations[filter_id] = dict(filter_params_tuple)
483+
484+
# Store filter configurations as private instance attribute
485+
self._filter_configurations = _filter_configurations
486+
487+
# Order streams by sampling rate (high to low)
488+
ordered_stream_keys = sorted(stream_props.keys(), reverse=True, key=lambda x: x[0])
489+
490+
stream_names = []
491+
stream_ids = []
492+
buffer_ids = []
493+
494+
for stream_key in ordered_stream_keys:
495+
stream_info = stream_props[stream_key]
496+
stream_id = stream_info["stream_id"]
497+
498+
# Unpack stream_key and format stream name
499+
sampling_rate, input_range, gain, input_inverted, filter_params_tuple = stream_key
500+
filter_id = filter_configs[filter_params_tuple]
501+
voltage_mv = int(input_range / 1000) if input_range is not None else 0
502+
stream_name = f"stream{stream_id}_{int(sampling_rate)}Hz_{voltage_mv}mVRange_f{filter_id}"
503+
504+
stream_names.append(stream_name)
505+
stream_ids.append(stream_id)
506+
buffer_ids.append("")
507+
400508
signal_streams = list(zip(stream_names, stream_ids, buffer_ids))
401509
else:
402510
signal_streams = []
511+
self._filter_configurations = {}
403512
signal_buffers = np.array([], dtype=_signal_buffer_dtype)
404513
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
405514

neo/rawio/neuralynxrawio/nlxheader.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ def __init__(self, filename, props_only=False):
199199
if not props_only:
200200
self._setTimeDate(txt_header)
201201

202+
# Normalize all types to proper Python types
203+
self._normalize_types()
204+
202205
@staticmethod
203206
def get_text_header(filename):
204207
"""
@@ -349,6 +352,70 @@ def _setTimeDate(self, txt_header):
349352
dt2 = sr.groupdict()
350353
self["recording_closed"] = dateutil.parser.parse(f"{dt2['date']} {dt2['time']}")
351354

355+
def _normalize_types(self):
356+
"""
357+
Convert all header values to proper Python types.
358+
359+
This ensures that:
360+
- Boolean strings ('True', 'False', 'Enabled', 'Disabled') become Python bools
361+
- Numeric strings ('0.1', '8000') become Python floats/ints
362+
- Single-element lists are extracted to scalars (for single-channel files)
363+
364+
This normalization makes the header values directly usable for
365+
stream identification without additional conversion in NeuralynxRawIO.
366+
"""
367+
368+
# Convert boolean strings to actual booleans
369+
bool_keys = [
370+
'DSPLowCutFilterEnabled',
371+
'DSPHighCutFilterEnabled',
372+
'DspDelayCompensation',
373+
]
374+
375+
for key in bool_keys:
376+
if key in self and isinstance(self[key], str):
377+
if self[key] in ('True', 'Enabled'):
378+
self[key] = True
379+
elif self[key] in ('False', 'Disabled'):
380+
self[key] = False
381+
382+
# Convert numeric strings to numbers
383+
numeric_keys = [
384+
'DspLowCutFrequency',
385+
'DspHighCutFrequency',
386+
'DspLowCutNumTaps',
387+
'DspHighCutNumTaps',
388+
]
389+
390+
for key in numeric_keys:
391+
if key in self and isinstance(self[key], str):
392+
try:
393+
# Try int first
394+
if '.' not in self[key]:
395+
self[key] = int(self[key])
396+
else:
397+
self[key] = float(self[key])
398+
except ValueError:
399+
# Keep as string if conversion fails
400+
pass
401+
402+
# Handle DspFilterDelay_µs (could be string or already converted)
403+
delay_key = 'DspFilterDelay_µs'
404+
if delay_key in self and isinstance(self[delay_key], str):
405+
try:
406+
self[delay_key] = int(self[delay_key])
407+
except ValueError:
408+
pass
409+
410+
# Extract single-channel InputRange from list
411+
# For multi-channel files, keep as list
412+
# For single-channel files, extract the single value
413+
if 'InputRange' in self and isinstance(self['InputRange'], list):
414+
if len(self['InputRange']) == 1:
415+
# Single channel file: extract the value
416+
self['InputRange'] = self['InputRange'][0]
417+
# else: multi-channel, keep as list
418+
352419
def type_of_recording(self):
353420
"""
354421
Determines type of recording in Ncs file with this header.

neo/test/rawiotest/test_neuralynxrawio.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,48 @@ def test_directory_in_data_folder(self):
213213
self.assertEqual(len(rawio.header["spike_channels"]), 8)
214214
self.assertEqual(len(rawio.header["event_channels"]), 2)
215215

216+
def test_two_streams_different_header_encoding(self):
217+
"""
218+
Test that streams are correctly differentiated based on filter parameters.
219+
This dataset contains eye-tracking and ephys channels with different filter settings.
220+
"""
221+
from pathlib import Path
222+
223+
# Get the path using the same machinery as other tests
224+
dname = self.get_local_path("neuralynx/two_streams_different_header_encoding")
225+
226+
# Test with Path object (as shown in user's notebook)
227+
rawio = NeuralynxRawIO(dirname=Path(dname))
228+
rawio.parse_header()
229+
230+
# Should have 2 streams due to different filter configurations
231+
self.assertEqual(rawio.signal_streams_count(), 2)
232+
233+
# Check stream names follow the new naming convention
234+
stream_names = [rawio.header["signal_streams"][i][0] for i in range(rawio.signal_streams_count())]
235+
236+
# Stream names should include sampling rate (Hz), voltage range (mV), and filter ID
237+
for stream_name in stream_names:
238+
self.assertRegex(stream_name, r"stream\d+_\d+Hz_\d+mVRange_f\d+")
239+
240+
# Verify we have the expected streams:
241+
# - Eye-tracking channels (CSC145, CSC146): 32000Hz, 100mV range, low-cut disabled
242+
# - Ephys channel (csc23_100): 32000Hz, 1mV range, low-cut enabled
243+
expected_names = {"stream0_32000Hz_100mVRange_f0", "stream1_32000Hz_1mVRange_f1"}
244+
self.assertEqual(set(stream_names), expected_names)
245+
246+
# Verify filter configurations are stored privately
247+
self.assertTrue(hasattr(rawio, "_filter_configurations"))
248+
self.assertEqual(len(rawio._filter_configurations), 2)
249+
250+
# Verify filter 0 (eye-tracking): low-cut disabled
251+
filter_0 = rawio._filter_configurations[0]
252+
self.assertFalse(filter_0.get("DSPLowCutFilterEnabled", True))
253+
254+
# Verify filter 1 (ephys): low-cut enabled
255+
filter_1 = rawio._filter_configurations[1]
256+
self.assertTrue(filter_1.get("DSPLowCutFilterEnabled", False))
257+
216258

217259
class TestNcsRecordingType(BaseTestRawIO, unittest.TestCase):
218260
"""

0 commit comments

Comments
 (0)