
Commit 613e151 ("fix")

2 parents: 72907f8 + 01f5577

7 files changed: +826 / -242 lines

neo/rawio/blackrockrawio.py (94 additions, 35 deletions)
@@ -290,12 +290,23 @@ def _parse_header(self):
         self._nsx_basic_header = {}
         self._nsx_ext_header = {}
         self._nsx_data_header = {}
+        self._nsx_sampling_frequency = {}

+        # Read headers
         for nsx_nb in self._avail_nsx:
             spec_version = self._nsx_spec[nsx_nb] = self._extract_nsx_file_spec(nsx_nb)
             # read nsx headers
             self._nsx_basic_header[nsx_nb], self._nsx_ext_header[nsx_nb] = self._read_nsx_header(spec_version, nsx_nb)

+            # Blackrock defines the period as the number of 1/30_000 seconds between data points,
+            # e.g. it is 1 for 30_000 Hz, 3 for 10_000 Hz, etc.
+            nsx_period = self._nsx_basic_header[nsx_nb]["period"]
+            sampling_rate = 30_000.0 / nsx_period
+            self._nsx_sampling_frequency[nsx_nb] = float(sampling_rate)
+
+        # Parse data packets
+        for nsx_nb in self._avail_nsx:
+
             # The only way to know if it is the Precision Time Protocol of file spec 3.0
             # is to check for nanosecond timestamp resolution.
             is_ptp_variant = (
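
The cached rate above follows directly from the header's `period` field. A minimal sketch of that relationship (the header values here are hypothetical, not read from any file):

```python
# Blackrock NSX headers store "period" as the number of 1/30_000 s ticks
# between samples, so the sampling rate is 30_000 / period.
NSX_CLOCK_HZ = 30_000.0

for period in (1, 3, 15):  # hypothetical header values
    print(f"period={period} -> {NSX_CLOCK_HZ / period:.0f} Hz")
# period=1 -> 30000 Hz, period=3 -> 10000 Hz, period=15 -> 2000 Hz
```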
@@ -352,7 +363,10 @@ def _parse_header(self):
             self._match_nsx_and_nev_segment_ids(nsx_nb)

         self.nsx_datas = {}
-        self.sig_sampling_rates = {}
+        # Keep the public attribute for backward compatibility, but use the private one
+        # internally; this may be deprecated at some point.
+        self.sig_sampling_rates = {
+            nsx_number: self._nsx_sampling_frequency[nsx_number] for nsx_number in self.nsx_to_load
+        }
         if len(self.nsx_to_load) > 0:
             for nsx_nb in self.nsx_to_load:
                 basic_header = self._nsx_basic_header[nsx_nb]
@@ -369,8 +383,7 @@ def _parse_header(self):
                     data_spec = spec_version
                 self.nsx_datas[nsx_nb] = self._read_nsx_data(data_spec, nsx_nb)

-                sr = float(self.main_sampling_rate / basic_header["period"])
-                self.sig_sampling_rates[nsx_nb] = sr
+                sr = self._nsx_sampling_frequency[nsx_nb]

                 if spec_version in ["2.2", "2.3", "3.0"]:
                     ext_header = self._nsx_ext_header[nsx_nb]
@@ -439,7 +452,7 @@ def _parse_header(self):
                 length = self.nsx_datas[nsx_nb][data_bl].shape[0]
                 if self._nsx_data_header[nsx_nb] is None:
                     t_start = 0.0
-                    t_stop = max(t_stop, length / self.sig_sampling_rates[nsx_nb])
+                    t_stop = max(t_stop, length / self._nsx_sampling_frequency[nsx_nb])
                 else:
                     timestamps = self._nsx_data_header[nsx_nb][data_bl]["timestamp"]
                     if hasattr(timestamps, "size") and timestamps.size == length:
@@ -448,7 +461,7 @@ def _parse_header(self):
                         t_stop = max(t_stop, timestamps[-1] / ts_res + sec_per_samp)
                     else:
                         t_start = timestamps / ts_res
-                        t_stop = max(t_stop, t_start + length / self._nsx_sampling_frequency[nsx_nb])
+                        t_stop = max(t_stop, t_start + length / self._nsx_sampling_frequency[nsx_nb])
                 self._sigs_t_starts[nsx_nb].append(t_start)

         if self._avail_files["nev"]:
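
As a worked example of the `t_stop` arithmetic in these two hunks (all values hypothetical): a block of 90_000 samples at 10 kHz that starts 2.5 s into the recording ends at 11.5 s.

```python
# Hypothetical values illustrating the t_stop computation above
length = 90_000                # samples in the data block
sampling_frequency = 10_000.0  # Hz, i.e. period == 3
t_start = 2.5                  # seconds, derived from the data-header timestamp
t_stop = t_start + length / sampling_frequency
assert t_stop == 11.5          # the block spans 9 s starting at 2.5 s
```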
@@ -964,36 +977,82 @@ def _read_nsx_dataheader_ptp(self, nsx_nb, filesize=None, offset=None):

         if offset is None:
             # This is read as a uint32 numpy scalar from the header, so we transform it to a python int
-            offset = int(self._nsx_basic_header[nsx_nb]["bytes_in_headers"])
+            header_size = int(self._nsx_basic_header[nsx_nb]["bytes_in_headers"])
+        else:
+            header_size = offset

         # Use the dictionary for PTP data type
         channel_count = int(self._nsx_basic_header[nsx_nb]["channel_count"])
         ptp_dt = NSX_DATA_HEADER_TYPES["3.0-ptp"](channel_count)
-        npackets = int((filesize - offset) / np.dtype(ptp_dt).itemsize)
-        struct_arr = np.memmap(filename, dtype=ptp_dt, shape=npackets, offset=offset, mode="r")
+        npackets = int((filesize - header_size) / np.dtype(ptp_dt).itemsize)
+        struct_arr = np.memmap(filename, dtype=ptp_dt, shape=npackets, offset=header_size, mode="r")

         if not np.all(struct_arr["num_data_points"] == 1):
             # some packets have more than 1 sample. Not actually ptp. Revert to non-ptp variant.
-            return self._read_nsx_dataheader_standard("3.0", nsx_nb, filesize=filesize, offset=offset)
-
-        # It is still possible there was a data break and the file has multiple segments.
-        # We can no longer rely on the presence of a header indicating a new segment,
-        # so we look for timestamp differences greater than double the expected interval.
-        _period = self._nsx_basic_header[nsx_nb]["period"]  # 30_000 ^-1 s per sample
-        _nominal_rate = 30_000 / _period  # samples per sec; maybe 30_000 should be ["sample_resolution"]
-        _clock_rate = self._nsx_basic_header[nsx_nb]["timestamp_resolution"]  # clocks per sec
-        clk_per_samp = _clock_rate / _nominal_rate  # clk/sec / smp/sec = clk/smp
-        seg_thresh_clk = int(2 * clk_per_samp)
-        seg_starts = np.hstack((0, 1 + np.argwhere(np.diff(struct_arr["timestamps"]) > seg_thresh_clk).flatten()))
-        for seg_ix, seg_start_idx in enumerate(seg_starts):
-            if seg_ix < (len(seg_starts) - 1):
-                seg_stop_idx = seg_starts[seg_ix + 1]
-            else:
-                seg_stop_idx = len(struct_arr) - 1
-            seg_offset = offset + seg_start_idx * struct_arr.dtype.itemsize
-            num_data_pts = seg_stop_idx - seg_start_idx
+            return self._read_nsx_dataheader_standard("3.0", nsx_nb, filesize=filesize, offset=header_size)
+
+        # Segment the data: at the moment we split wherever consecutive samples
+        # are farther apart than twice the sampling period.
+        sampling_rate = self._nsx_sampling_frequency[nsx_nb]
+        segmentation_threshold = 2.0 / sampling_rate
+
+        # The raw timestamps are the indices of an ideal clock that ticks at `timestamp_resolution` times per second.
+        # We convert these indices to actual timestamps in seconds.
+        raw_timestamps = struct_arr["timestamps"]
+        timestamps_sampling_rate = self._nsx_basic_header[nsx_nb][
+            "timestamp_resolution"
+        ]  # clocks per sec, uint64 or uint32
+        timestamps_in_seconds = raw_timestamps / timestamps_sampling_rate
+
+        time_differences = np.diff(timestamps_in_seconds)
+        gap_indices = np.argwhere(time_differences > segmentation_threshold).flatten()
+        segment_starts = np.hstack((0, 1 + gap_indices))
+
+        # Report gaps if any are found
+        if len(gap_indices) > 0:
+            import warnings
+
+            threshold_ms = segmentation_threshold * 1000
+
+            # Calculate all gap details in vectorized operations
+            gap_durations_seconds = time_differences[gap_indices]
+            gap_durations_ms = gap_durations_seconds * 1000
+            gap_positions_seconds = timestamps_in_seconds[gap_indices] - timestamps_in_seconds[0]
+
+            # Build gap detail lines all at once
+            gap_detail_lines = [
+                f"| {index:>15,} | {pos:>21.6f} | {dur:>21.3f} |\n"
+                for index, pos, dur in zip(gap_indices, gap_positions_seconds, gap_durations_ms)
+            ]
+
+            segmentation_report_message = (
+                f"\nFound {len(gap_indices)} gaps for nsx {nsx_nb} where samples are farther apart than {threshold_ms:.3f} ms.\n"
+                f"Data will be segmented at these locations to create {len(segment_starts)} segments.\n\n"
+                "Gap Details:\n"
+                "+-----------------+-----------------------+-----------------------+\n"
+                "|    Sample Index |             Sample at |              Gap Jump |\n"
+                "|                 |             (Seconds) |        (Milliseconds) |\n"
+                "+-----------------+-----------------------+-----------------------+\n"
+                + "".join(gap_detail_lines)
+                + "+-----------------+-----------------------+-----------------------+\n"
+            )
+            warnings.warn(segmentation_report_message)
+
+        # Calculate all segment boundaries and derived values in one operation
+        segment_boundaries = list(segment_starts) + [len(struct_arr) - 1]
+        segment_num_data_points = [
+            segment_boundaries[i + 1] - segment_boundaries[i] for i in range(len(segment_starts))
+        ]
+
+        size_of_data_block = struct_arr.dtype.itemsize
+        segment_offsets = [header_size + pos * size_of_data_block for pos in segment_starts]
+
+        num_segments = len(segment_starts)
+        for segment_index in range(num_segments):
+            seg_offset = segment_offsets[segment_index]
+            num_data_pts = segment_num_data_points[segment_index]
             seg_struct_arr = np.memmap(filename, dtype=ptp_dt, shape=num_data_pts, offset=seg_offset, mode="r")
-            data_header[seg_ix] = {
+            data_header[segment_index] = {
                 "header": None,
                 "timestamp": seg_struct_arr["timestamps"],  # Note, this is an array, not a scalar
                 "nb_data_points": num_data_pts,
@@ -1028,7 +1087,7 @@ def _read_nsx_data_v21(self, nsx_nb):
         """
         Extract nsx data from a 2.1 .nsx file
         """
-        filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
+        filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

         # get shape of data
         shape = (
@@ -1071,7 +1130,7 @@ def _read_nsx_data_ptp(self, nsx_nb):
         yielding a timestamp per sample. Blocks can arise
         if the recording was paused by the user.
         """
-        filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
+        filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

         # Use the dictionary for PTP data type
         channel_count = int(self._nsx_basic_header[nsx_nb]["channel_count"])
@@ -1146,10 +1205,6 @@ def _read_nev_header(self, spec, filename):

         nev_basic_header = np.fromfile(filename, count=1, dtype=dt0)[0]

-        # Get extended header types for this spec
-        header_types = NEV_EXT_HEADER_TYPES_BY_SPEC[spec]
-
-        # extended header reading
         shape = nev_basic_header["nb_ext_headers"]
         offset_dt0 = np.dtype(dt0).itemsize

@@ -1158,6 +1213,10 @@ def _read_nev_header(self, spec, filename):

         raw_ext_header = np.memmap(filename, offset=offset_dt0, dtype=dt1, shape=shape, mode="r")

+
+        # Get extended header types for this spec
+        header_types = NEV_EXT_HEADER_TYPES_BY_SPEC[spec]
+
         # Parse extended headers by packet type
         # Strategy: view() entire array first, then mask for efficiency
         # Since all NEV extended header packets are fixed-width (32 bytes), temporarily
23992458
# PTP variant has a completely different structure with samples embedded
24002459
"3.0-ptp": lambda channel_count: [
24012460
("reserved", "uint8"),
2402-
("timestamps", "uint64"),
2461+
("timestamps", "uint64"),
24032462
("num_data_points", "uint32"),
2404-
("samples", "int16", channel_count)
2463+
("samples", "int16", (channel_count,))
24052464
]
24062465
}
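
On the `(channel_count,)` change: for counts greater than one, a bare integer and a 1-tuple are equivalent in a numpy dtype spec, but numpy treats a bare count of 1 as a deprecated synonym for a scalar field, so the tuple form is the safe way to always get a subarray. A quick check (assumes only numpy):

```python
import numpy as np

# The tuple form always yields a subarray field, even for a single channel.
dt_one = np.dtype([("samples", "int16", (1,))])
print(dt_one["samples"].shape)   # (1,) -> one value per packet, kept as an array

dt_many = np.dtype([("samples", "int16", (4,))])
print(dt_many["samples"].shape)  # (4,)
```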

neo/rawio/brainvisionrawio.py (48 additions, 0 deletions)
@@ -55,6 +55,9 @@ def _parse_header(self):
         marker_filename = self.filename.replace(bname, vhdr_header["Common Infos"]["MarkerFile"])
         binary_filename = self.filename.replace(bname, vhdr_header["Common Infos"]["DataFile"])

+        marker_filename = self._ensure_filename(marker_filename, "marker", "MarkerFile")
+        binary_filename = self._ensure_filename(binary_filename, "data", "DataFile")
+
         if vhdr_header["Common Infos"]["DataFormat"] != "BINARY":
             raise NeoReadWriteError(
                 f"Only `BINARY` format has been implemented. Current Data Format is {vhdr_header['Common Infos']['DataFormat']}"
@@ -236,6 +239,51 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)
     def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id):
         return self._buffer_descriptions[block_index][seg_index][buffer_id]

+    def _ensure_filename(self, filename, kind, entry_name):
+        if not os.path.exists(filename):
+            # file not found; the subsequent import stage would fail
+            ext = os.path.splitext(filename)[1]
+            # Check if we can fall back to a file with the same prefix as the .vhdr.
+            # This can happen when users rename their files but forget to edit the
+            # .vhdr file to fix the path reference to the binary and marker files,
+            # in which case the import will fail. These files come in triples, like
+            # myfile.vhdr, myfile.eeg, and myfile.vmrk; this code will thus pick
+            # the next best alternative.
+            alt_name = self.filename.replace(".vhdr", ext)
+            if os.path.exists(alt_name):
+                self.logger.warning(
+                    f"The {kind} file {filename} was not found, but found a file whose "
+                    f"prefix matched the .vhdr ({os.path.basename(alt_name)}). Using "
+                    f"this file instead."
+                )
+                filename = alt_name
+            else:
+                # we neither found the file referenced in the .vhdr file nor a file with
+                # the same name as the header and the desired extension; most likely a
+                # file went missing or was renamed in an inconsistent fashion; generate
+                # a useful error message
+                header_dname = os.path.dirname(self.filename)
+                header_bname = os.path.basename(self.filename)
+                referenced_bname = os.path.basename(filename)
+                alt_bname = os.path.basename(alt_name)
+                if alt_bname != referenced_bname:
+                    # this is only needed when the two candidate file names differ
+                    detail = f" is named either as per the {entry_name}={referenced_bname} line in the .vhdr file, or"
+                else:
+                    # we omit it when we can, to make the message less confusing
+                    detail = ""
+                self.logger.error(
+                    f"Did not find the {kind} file associated with .vhdr (header) "
+                    f"file {header_bname!r} in folder {header_dname!r}.\n Please make "
+                    f"sure the file{detail} is named the same way as the .vhdr file, but "
+                    f"ending in {ext} (i.e. {alt_bname}).\n The import will likely fail, "
+                    f"but if it goes through, you can ignore this message (the check "
+                    f"can misfire on networked file systems)."
+                )
+        return filename
+

 def read_brainvsion_soup(filename):
     with open(filename, "r", encoding="utf8") as f:
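
The fallback rule `_ensure_filename` implements can also be shown standalone. A minimal sketch under the assumption that a renamed triple kept the .vhdr prefix; the helper name and paths are hypothetical:

```python
import os

# If the file the .vhdr references is missing, try a sibling that shares
# the .vhdr prefix but carries the referenced file's extension.
def pick_candidate(vhdr_path, referenced_path):
    if os.path.exists(referenced_path):
        return referenced_path
    ext = os.path.splitext(referenced_path)[1]  # e.g. ".eeg" or ".vmrk"
    alt = vhdr_path.replace(".vhdr", ext)       # myfile.vhdr -> myfile.eeg
    return alt if os.path.exists(alt) else referenced_path

# e.g. pick_candidate("/data/myfile.vhdr", "/data/old_name.eeg")
# returns "/data/myfile.eeg" when that sibling exists on disk.
```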
