Skip to content

Commit f279132

Browse files
authored
Merge pull request #1740 from h-mayorquin/fix_blackrock
Blackrock add block validation
2 parents 40f5eaf + 2240a29 commit f279132

File tree

1 file changed

+86
-60
lines changed

1 file changed

+86
-60
lines changed

neo/rawio/blackrockrawio.py

Lines changed: 86 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -323,19 +323,21 @@ def _parse_header(self):
323323
self.__nsx_data_header = {}
324324

325325
for nsx_nb in self._avail_nsx:
326-
spec = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb)
326+
spec_version = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb)
327327
# read nsx headers
328-
self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = self.__nsx_header_reader[spec](nsx_nb)
328+
nsx_header_reader = self.__nsx_header_reader[spec_version]
329+
self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = nsx_header_reader(nsx_nb)
329330

330-
# The only way to know if it is the PTP-variant of file spec 3.0
331+
# The only way to know if it is the Precision Time Protocol of file spec 3.0
331332
# is to check for nanosecond timestamp resolution.
332-
if (
333+
is_ptp_variant = (
333334
"timestamp_resolution" in self.__nsx_basic_header[nsx_nb].dtype.names
334335
and self.__nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000
335-
):
336+
)
337+
if is_ptp_variant:
336338
nsx_dataheader_reader = self.__nsx_dataheader_reader["3.0-ptp"]
337339
else:
338-
nsx_dataheader_reader = self.__nsx_dataheader_reader[spec]
340+
nsx_dataheader_reader = self.__nsx_dataheader_reader[spec_version]
339341
# for nsxdef get_analogsignal_shape(self, block_index, seg_index):
340342
self.__nsx_data_header[nsx_nb] = nsx_dataheader_reader(nsx_nb)
341343

@@ -355,8 +357,12 @@ def _parse_header(self):
355357
else:
356358
raise (ValueError("nsx_to_load is wrong"))
357359

358-
if not all(nsx_nb in self._avail_nsx for nsx_nb in self.nsx_to_load):
359-
raise FileNotFoundError(f"nsx_to_load does not match available nsx list")
360+
missing_nsx_files = [nsx_nb for nsx_nb in self.nsx_to_load if nsx_nb not in self._avail_nsx]
361+
if missing_nsx_files:
362+
missing_list = ", ".join(f"ns{nsx_nb}" for nsx_nb in missing_nsx_files)
363+
raise FileNotFoundError(
364+
f"Requested NSX file(s) not found: {missing_list}. Available NSX files: {self._avail_nsx}"
365+
)
360366

361367
# check that all files come from the same specification
362368
all_spec = [self.__nsx_spec[nsx_nb] for nsx_nb in self.nsx_to_load]
@@ -381,27 +387,29 @@ def _parse_header(self):
381387
self.sig_sampling_rates = {}
382388
if len(self.nsx_to_load) > 0:
383389
for nsx_nb in self.nsx_to_load:
384-
spec = self.__nsx_spec[nsx_nb]
385-
# The only way to know if it is the PTP-variant of file spec 3.0
390+
basic_header = self.__nsx_basic_header[nsx_nb]
391+
spec_version = self.__nsx_spec[nsx_nb]
392+
# The only way to know if it is the Precision Time Protocol of file spec 3.0
386393
# is to check for nanosecond timestamp resolution.
387-
if (
388-
"timestamp_resolution" in self.__nsx_basic_header[nsx_nb].dtype.names
389-
and self.__nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000
390-
):
394+
is_ptp_variant = (
395+
"timestamp_resolution" in basic_header.dtype.names
396+
and basic_header["timestamp_resolution"] == 1_000_000_000
397+
)
398+
if is_ptp_variant:
391399
_data_reader_fun = self.__nsx_data_reader["3.0-ptp"]
392400
else:
393-
_data_reader_fun = self.__nsx_data_reader[spec]
401+
_data_reader_fun = self.__nsx_data_reader[spec_version]
394402
self.nsx_datas[nsx_nb] = _data_reader_fun(nsx_nb)
395403

396-
sr = float(self.main_sampling_rate / self.__nsx_basic_header[nsx_nb]["period"])
404+
sr = float(self.main_sampling_rate / basic_header["period"])
397405
self.sig_sampling_rates[nsx_nb] = sr
398406

399-
if spec in ["2.2", "2.3", "3.0"]:
407+
if spec_version in ["2.2", "2.3", "3.0"]:
400408
ext_header = self.__nsx_ext_header[nsx_nb]
401-
elif spec == "2.1":
409+
elif spec_version == "2.1":
402410
ext_header = []
403411
keys = ["labels", "units", "min_analog_val", "max_analog_val", "min_digital_val", "max_digital_val"]
404-
params = self.__nsx_params[spec](nsx_nb)
412+
params = self.__nsx_params[spec_version](nsx_nb)
405413
for i in range(len(params["labels"])):
406414
d = {}
407415
for key in keys:
@@ -415,11 +423,11 @@ def _parse_header(self):
415423
signal_buffers.append((stream_name, buffer_id))
416424
signal_streams.append((stream_name, stream_id, buffer_id))
417425
for i, chan in enumerate(ext_header):
418-
if spec in ["2.2", "2.3", "3.0"]:
426+
if spec_version in ["2.2", "2.3", "3.0"]:
419427
ch_name = chan["electrode_label"].decode()
420428
ch_id = str(chan["electrode_id"])
421429
units = chan["units"].decode()
422-
elif spec == "2.1":
430+
elif spec_version == "2.1":
423431
ch_name = chan["labels"]
424432
ch_id = str(self.__nsx_ext_header[nsx_nb][i]["electrode_id"])
425433
units = chan["units"]
@@ -809,7 +817,7 @@ def __extract_nsx_file_spec(self, nsx_nb):
809817
"""
810818
Extract file specification from an .nsx file.
811819
"""
812-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
820+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
813821

814822
# Header structure of files specification 2.2 and higher. For files 2.1
815823
# and lower, the entries ver_major and ver_minor are not supported.
@@ -829,7 +837,7 @@ def __extract_nev_file_spec(self):
829837
"""
830838
Extract file specification from an .nev file
831839
"""
832-
filename = ".".join([self._filenames["nev"], "nev"])
840+
filename = f"{self._filenames['nev']}.nev"
833841
# Header structure of files specification 2.2 and higher. For files 2.1
834842
# and lower, the entries ver_major and ver_minor are not supported.
835843
dt0 = [("file_id", "S8"), ("ver_major", "uint8"), ("ver_minor", "uint8")]
@@ -879,7 +887,7 @@ def __read_nsx_header_variant_b(self, nsx_nb):
879887
"""
880888
Extract nsx header information from a 2.2 or 2.3 .nsx file
881889
"""
882-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
890+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
883891

884892
# basic header (file_id: NEURALCD)
885893
dt0 = [
@@ -911,7 +919,6 @@ def __read_nsx_header_variant_b(self, nsx_nb):
911919

912920
# extended header (type: CC)
913921
offset_dt0 = np.dtype(dt0).itemsize
914-
shape = nsx_basic_header["channel_count"]
915922
dt1 = [
916923
("type", "S2"),
917924
("electrode_id", "uint16"),
@@ -930,28 +937,32 @@ def __read_nsx_header_variant_b(self, nsx_nb):
930937
# filter settings used to create nsx from source signal
931938
("hi_freq_corner", "uint32"),
932939
("hi_freq_order", "uint32"),
933-
("hi_freq_type", "uint16"), # 0=None, 1=Butterworth
940+
("hi_freq_type", "uint16"), # 0=None, 1=Butterworth, 2=Chebyshev
934941
("lo_freq_corner", "uint32"),
935942
("lo_freq_order", "uint32"),
936943
("lo_freq_type", "uint16"),
937-
] # 0=None, 1=Butterworth
944+
] # 0=None, 1=Butterworth, 2=Chebyshev
938945

939-
nsx_ext_header = np.memmap(filename, shape=shape, offset=offset_dt0, dtype=dt1, mode="r")
946+
channel_count = int(nsx_basic_header["channel_count"])
947+
nsx_ext_header = np.memmap(filename, shape=channel_count, offset=offset_dt0, dtype=dt1, mode="r")
940948

941949
return nsx_basic_header, nsx_ext_header
942950

943951
def __read_nsx_dataheader(self, nsx_nb, offset):
944952
"""
945953
Reads data header following the given offset of an nsx file.
946954
"""
947-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
955+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
948956

949-
ts_size = "uint64" if self.__nsx_basic_header[nsx_nb]["ver_major"] >= 3 else "uint32"
957+
major_version = self.__nsx_basic_header[nsx_nb]["ver_major"]
958+
ts_size = "uint64" if major_version >= 3 else "uint32"
959+
960+
# dtypes data header, the header flag is always set to 1
961+
dt2 = [("header_flag", "uint8"), ("timestamp", ts_size), ("nb_data_points", "uint32")]
950962

951-
# dtypes data header
952-
dt2 = [("header", "uint8"), ("timestamp", ts_size), ("nb_data_points", "uint32")]
963+
packet_header = np.memmap(filename, dtype=dt2, shape=1, offset=offset, mode="r")[0]
953964

954-
return np.memmap(filename, dtype=dt2, shape=1, offset=offset, mode="r")[0]
965+
return packet_header
955966

956967
def __read_nsx_dataheader_variant_a(self, nsx_nb, filesize=None, offset=None):
957968
"""
@@ -971,32 +982,46 @@ def __read_nsx_dataheader_variant_b(
971982
Reads the nsx data header for each data block following the offset of
972983
file spec 2.2, 2.3, and 3.0.
973984
"""
974-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
985+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
975986

976-
filesize = self.__get_file_size(filename)
987+
filesize_bytes = self.__get_file_size(filename)
977988

978989
data_header = {}
979-
index = 0
980-
981990
if offset is None:
982-
offset = self.__nsx_basic_header[nsx_nb]["bytes_in_headers"]
983-
984-
while offset < filesize:
985-
dh = self.__read_nsx_dataheader(nsx_nb, offset)
986-
data_header[index] = {
987-
"header": dh["header"],
988-
"timestamp": dh["timestamp"],
989-
"nb_data_points": dh["nb_data_points"],
990-
"offset_to_data_block": offset + dh.dtype.itemsize,
991+
offset_to_first_data_block = int(self.__nsx_basic_header[nsx_nb]["bytes_in_headers"])
992+
else:
993+
offset_to_first_data_block = int(offset)
994+
995+
channel_count = int(self.__nsx_basic_header[nsx_nb]["channel_count"])
996+
current_offset_bytes = offset_to_first_data_block
997+
data_block_index = 0
998+
while current_offset_bytes < filesize_bytes:
999+
packet_header = self.__read_nsx_dataheader(nsx_nb, current_offset_bytes)
1000+
header_flag = packet_header["header_flag"]
1001+
# NSX data blocks must have header_flag = 1, other values indicate file corruption
1002+
if header_flag != 1:
1003+
raise ValueError(
1004+
f"Invalid NSX data block header at offset {current_offset_bytes:#x} in ns{nsx_nb} file. "
1005+
f"Expected header_flag=1, got {header_flag}. "
1006+
f"This may indicate file corruption or unsupported NSX format variant. "
1007+
f"Block index: {data_block_index}, File size: {filesize_bytes} bytes"
1008+
)
1009+
timestamp = packet_header["timestamp"]
1010+
num_data_points = int(packet_header["nb_data_points"])
1011+
offset_to_data_block_start = current_offset_bytes + packet_header.dtype.itemsize
1012+
1013+
data_header[data_block_index] = {
1014+
"header": header_flag,
1015+
"timestamp": timestamp,
1016+
"nb_data_points": num_data_points,
1017+
"offset_to_data_block": offset_to_data_block_start,
9911018
}
9921019

993-
# data size = number of data points * (2bytes * number of channels)
994-
# use of `int` avoids overflow problem
995-
data_size = int(dh["nb_data_points"]) * int(self.__nsx_basic_header[nsx_nb]["channel_count"]) * 2
996-
# define new offset (to possible next data block)
997-
offset = int(data_header[index]["offset_to_data_block"]) + data_size
1020+
# Jump to the next data block, the data is encoded as int16
1021+
data_block_size_bytes = num_data_points * channel_count * np.dtype("int16").itemsize
1022+
current_offset_bytes = offset_to_data_block_start + data_block_size_bytes
9981023

999-
index += 1
1024+
data_block_index += 1
10001025

10011026
return data_header
10021027

@@ -1082,19 +1107,20 @@ def __read_nsx_data_variant_b(self, nsx_nb):
10821107
Extract nsx data (blocks) from a 2.2, 2.3, or 3.0 .nsx file.
10831108
Blocks can arise if the recording was paused by the user.
10841109
"""
1085-
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
1110+
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"
10861111

10871112
data = {}
1088-
for data_bl in self.__nsx_data_header[nsx_nb].keys():
1113+
data_header = self.__nsx_data_header[nsx_nb]
1114+
number_of_channels = int(self.__nsx_basic_header[nsx_nb]["channel_count"])
1115+
1116+
for data_block in data_header.keys():
10891117
# get shape and offset of data
1090-
shape = (
1091-
int(self.__nsx_data_header[nsx_nb][data_bl]["nb_data_points"]),
1092-
int(self.__nsx_basic_header[nsx_nb]["channel_count"]),
1093-
)
1094-
offset = int(self.__nsx_data_header[nsx_nb][data_bl]["offset_to_data_block"])
1118+
number_of_samples = int(data_header[data_block]["nb_data_points"])
1119+
shape = (number_of_samples, number_of_channels)
1120+
offset = int(data_header[data_block]["offset_to_data_block"])
10951121

10961122
# read data
1097-
data[data_bl] = np.memmap(filename, dtype="int16", shape=shape, offset=offset, mode="r")
1123+
data[data_block] = np.memmap(filename, dtype="int16", shape=shape, offset=offset, mode="r")
10981124

10991125
return data
11001126

0 commit comments

Comments
 (0)