Skip to content

Commit 6c17e56

Browse files
committed
fix this
2 parents 959cef8 + 72ec76a commit 6c17e56

File tree

2 files changed

+88
-36
lines changed

2 files changed

+88
-36
lines changed

neo/rawio/blackrockrawio.py

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,25 @@ def _parse_header(self):
318318
self._nsx_basic_header = {}
319319
self._nsx_ext_header = {}
320320
self._nsx_data_header = {}
321+
self._nsx_sampling_frequency = {}
321322

323+
# Read headers
322324
for nsx_nb in self._avail_nsx:
323325
spec_version = self._nsx_spec[nsx_nb] = self._extract_nsx_file_spec(nsx_nb)
324326
# read nsx headers
325327
nsx_header_reader = self._nsx_header_reader[spec_version]
326328
self._nsx_basic_header[nsx_nb], self._nsx_ext_header[nsx_nb] = nsx_header_reader(nsx_nb)
329+
330+
# Blackrock defines the period as the number of 1/30_000-second intervals between data points
331+
# E.g. it is 1 for 30_000, 3 for 10_000, etc
332+
nsx_period = self._nsx_basic_header[nsx_nb]["period"]
333+
sampling_rate = 30_000.0 / nsx_period
334+
self._nsx_sampling_frequency[nsx_nb] = float(sampling_rate)
327335

336+
337+
# Parse data packets
338+
for nsx_nb in self._avail_nsx:
339+
328340
# The only way to know if it is the Precision Time Protocol of file spec 3.0
329341
# is to check for nanosecond timestamp resolution.
330342
is_ptp_variant = (
@@ -381,7 +393,8 @@ def _parse_header(self):
381393
self._match_nsx_and_nev_segment_ids(nsx_nb)
382394

383395
self.nsx_datas = {}
384-
self.sig_sampling_rates = {}
396+
# Keep public attribute for backward compatibility but let's use the private one and maybe deprecate this at some point
397+
self.sig_sampling_rates = {nsx_number: self._nsx_sampling_frequency[nsx_number] for nsx_number in self.nsx_to_load}
385398
if len(self.nsx_to_load) > 0:
386399
for nsx_nb in self.nsx_to_load:
387400
basic_header = self._nsx_basic_header[nsx_nb]
@@ -398,8 +411,7 @@ def _parse_header(self):
398411
_data_reader_fun = self._nsx_data_reader[spec_version]
399412
self.nsx_datas[nsx_nb] = _data_reader_fun(nsx_nb)
400413

401-
sr = float(self.main_sampling_rate / basic_header["period"])
402-
self.sig_sampling_rates[nsx_nb] = sr
414+
sr = self._nsx_sampling_frequency[nsx_nb]
403415

404416
if spec_version in ["2.2", "2.3", "3.0"]:
405417
ext_header = self._nsx_ext_header[nsx_nb]
@@ -468,7 +480,7 @@ def _parse_header(self):
468480
length = self.nsx_datas[nsx_nb][data_bl].shape[0]
469481
if self._nsx_data_header[nsx_nb] is None:
470482
t_start = 0.0
471-
t_stop = max(t_stop, length / self.sig_sampling_rates[nsx_nb])
483+
t_stop = max(t_stop, length / self._nsx_sampling_frequency[nsx_nb])
472484
else:
473485
timestamps = self._nsx_data_header[nsx_nb][data_bl]["timestamp"]
474486
if hasattr(timestamps, "size") and timestamps.size == length:
@@ -477,7 +489,7 @@ def _parse_header(self):
477489
t_stop = max(t_stop, timestamps[-1] / ts_res + sec_per_samp)
478490
else:
479491
t_start = timestamps / ts_res
480-
t_stop = max(t_stop, t_start + length / self.sig_sampling_rates[nsx_nb])
492+
t_stop = max(t_stop, t_start + length / self._nsx_sampling_frequency[nsx_nb])
481493
self._sigs_t_starts[nsx_nb].append(t_start)
482494

483495
if self._avail_files["nev"]:
@@ -1036,43 +1048,83 @@ def _read_nsx_dataheader_spec_v30_ptp(
10361048
filesize = self._get_file_size(filename)
10371049

10381050
data_header = {}
1039-
index = 0
1040-
10411051
if offset is None:
10421052
# This is read as an uint32 numpy scalar from the header so we transform it to python int
1043-
offset = int(self._nsx_basic_header[nsx_nb]["bytes_in_headers"])
1053+
header_size = int(self._nsx_basic_header[nsx_nb]["bytes_in_headers"])
1054+
else:
1055+
header_size = offset
10441056

10451057
ptp_dt = [
10461058
("reserved", "uint8"),
10471059
("timestamps", "uint64"),
10481060
("num_data_points", "uint32"),
1049-
("samples", "int16", self._nsx_basic_header[nsx_nb]["channel_count"]),
1061+
("samples", "int16", (self._nsx_basic_header[nsx_nb]["channel_count"],)),
10501062
]
1051-
npackets = int((filesize - offset) / np.dtype(ptp_dt).itemsize)
1052-
struct_arr = np.memmap(filename, dtype=ptp_dt, shape=npackets, offset=offset, mode="r")
1063+
npackets = int((filesize - header_size) / np.dtype(ptp_dt).itemsize)
1064+
struct_arr = np.memmap(filename, dtype=ptp_dt, shape=npackets, offset=header_size, mode="r")
10531065

10541066
if not np.all(struct_arr["num_data_points"] == 1):
10551067
# some packets have more than 1 sample. Not actually ptp. Revert to non-ptp variant.
1056-
return self._read_nsx_dataheader_spec_v22_30(nsx_nb, filesize=filesize, offset=offset)
1057-
1058-
# It is still possible there was a data break and the file has multiple segments.
1059-
# We can no longer rely on the presence of a header indicating a new segment,
1060-
# so we look for timestamp differences greater than double the expected interval.
1061-
_period = self._nsx_basic_header[nsx_nb]["period"] # 30_000 ^-1 s per sample
1062-
_nominal_rate = 30_000 / _period # samples per sec; maybe 30_000 should be ["sample_resolution"]
1063-
_clock_rate = self._nsx_basic_header[nsx_nb]["timestamp_resolution"] # clocks per sec
1064-
clk_per_samp = _clock_rate / _nominal_rate # clk/sec / smp/sec = clk/smp
1065-
seg_thresh_clk = int(2 * clk_per_samp)
1066-
seg_starts = np.hstack((0, 1 + np.argwhere(np.diff(struct_arr["timestamps"]) > seg_thresh_clk).flatten()))
1067-
for seg_ix, seg_start_idx in enumerate(seg_starts):
1068-
if seg_ix < (len(seg_starts) - 1):
1069-
seg_stop_idx = seg_starts[seg_ix + 1]
1070-
else:
1071-
seg_stop_idx = len(struct_arr) - 1
1072-
seg_offset = offset + seg_start_idx * struct_arr.dtype.itemsize
1073-
num_data_pts = seg_stop_idx - seg_start_idx
1068+
return self._read_nsx_dataheader_spec_v22_30(nsx_nb, filesize=filesize, offset=header_size)
1069+
1070+
1071+
# Segment the data: at the moment, we start a new segment wherever the data has gaps that are longer
1072+
# than twice the sampling period.
1073+
sampling_rate = self._nsx_sampling_frequency[nsx_nb]
1074+
segmentation_threshold = 2.0 / sampling_rate
1075+
1076+
# The raw timestamps are the indices of an ideal clock that ticks `timestamp_resolution` times per second.
1077+
# We convert these indices to actual timestamps in seconds
1078+
raw_timestamps = struct_arr["timestamps"]
1079+
timestamps_sampling_rate = self._nsx_basic_header[nsx_nb]["timestamp_resolution"] # clocks per sec uint64 or uint32
1080+
timestamps_in_seconds = raw_timestamps / timestamps_sampling_rate
1081+
1082+
time_differences = np.diff(timestamps_in_seconds)
1083+
gap_indices = np.argwhere(time_differences > segmentation_threshold).flatten()
1084+
segment_starts = np.hstack((0, 1 + gap_indices))
1085+
1086+
# Report gaps if any are found
1087+
if len(gap_indices) > 0:
1088+
import warnings
1089+
threshold_ms = segmentation_threshold * 1000
1090+
1091+
# Calculate all gap details in vectorized operations
1092+
gap_durations_seconds = time_differences[gap_indices]
1093+
gap_durations_ms = gap_durations_seconds * 1000
1094+
gap_positions_seconds = timestamps_in_seconds[gap_indices] - timestamps_in_seconds[0]
1095+
1096+
# Build gap detail lines all at once
1097+
gap_detail_lines = [
1098+
f"| {index:>15,} | {pos:>21.6f} | {dur:>21.3f} |\n"
1099+
for index, pos, dur in zip(gap_indices, gap_positions_seconds, gap_durations_ms)
1100+
]
1101+
1102+
segmentation_report_message = (
1103+
f"\nFound {len(gap_indices)} gaps for nsx {nsx_nb} where samples are farther apart than {threshold_ms:.3f} ms.\n"
1104+
f"Data will be segmented at these locations to create {len(segment_starts)} segments.\n\n"
1105+
"Gap Details:\n"
1106+
"+-----------------+-----------------------+-----------------------+\n"
1107+
"| Sample Index | Sample at | Gap Jump |\n"
1108+
"| | (Seconds) | (Milliseconds) |\n"
1109+
"+-----------------+-----------------------+-----------------------+\n"
1110+
+ ''.join(gap_detail_lines) +
1111+
"+-----------------+-----------------------+-----------------------+\n"
1112+
)
1113+
warnings.warn(segmentation_report_message)
1114+
1115+
# Calculate all segment boundaries and derived values in one operation
1116+
segment_boundaries = list(segment_starts) + [len(struct_arr) - 1]
1117+
segment_num_data_points = [segment_boundaries[i+1] - segment_boundaries[i] for i in range(len(segment_starts))]
1118+
1119+
size_of_data_block = struct_arr.dtype.itemsize
1120+
segment_offsets = [header_size + pos * size_of_data_block for pos in segment_starts]
1121+
1122+
num_segments = len(segment_starts)
1123+
for segment_index in range(num_segments):
1124+
seg_offset = segment_offsets[segment_index]
1125+
num_data_pts = segment_num_data_points[segment_index]
10741126
seg_struct_arr = np.memmap(filename, dtype=ptp_dt, shape=num_data_pts, offset=seg_offset, mode="r")
1075-
data_header[seg_ix] = {
1127+
data_header[segment_index] = {
10761128
"header": None,
10771129
"timestamp": seg_struct_arr["timestamps"], # Note, this is an array, not a scalar
10781130
"nb_data_points": num_data_pts,
@@ -1084,7 +1136,7 @@ def _read_nsx_data_spec_v21(self, nsx_nb):
10841136
"""
10851137
Extract nsx data from a 2.1 .nsx file
10861138
"""
1087-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
1139+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
10881140

10891141
# get shape of data
10901142
shape = (
@@ -1127,13 +1179,13 @@ def _read_nsx_data_spec_v30_ptp(self, nsx_nb):
11271179
yielding a timestamp per sample. Blocks can arise
11281180
if the recording was paused by the user.
11291181
"""
1130-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
1182+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
11311183

11321184
ptp_dt = [
11331185
("reserved", "uint8"),
11341186
("timestamps", "uint64"),
11351187
("num_data_points", "uint32"),
1136-
("samples", "int16", self._nsx_basic_header[nsx_nb]["channel_count"]),
1188+
("samples", "int16", (self._nsx_basic_header[nsx_nb]["channel_count"],)),
11371189
]
11381190

11391191
data = {}

neo/test/rawiotest/test_spikegadgetsrawio.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ def test_parse_header_missing_channels(self):
5656
def test_opening_gibberish_file(self):
5757
"""Test that parsing a file without </Configuration> raises ValueError instead of infinite loop."""
5858
# Create a temporary file with gibberish content that doesn't have the required tag
59-
with tempfile.NamedTemporaryFile(mode='wb', suffix='.rec') as temp_file:
59+
with tempfile.NamedTemporaryFile(mode="wb", suffix=".rec") as temp_file:
6060
# Write simple gibberish content without the required </Configuration> tag
6161
temp_file.write(b"gibberish\n")
6262
temp_file.flush()
63-
63+
6464
reader = SpikeGadgetsRawIO(filename=temp_file.name)
6565
with self.assertRaises(ValueError) as cm:
6666
reader.parse_header()
67-
67+
6868
self.assertIn("xml header does not contain '</Configuration>'", str(cm.exception))

0 commit comments

Comments
 (0)