Skip to content

Commit b8adefb

Browse files
committed
wip updates
based on Heberto + more intan reading.
1 parent 44b8792 commit b8adefb

File tree

1 file changed

+51
-18
lines changed

1 file changed

+51
-18
lines changed

neo/rawio/intanrawio.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
1919
"""
2020
from pathlib import Path
21+
import os
2122
from collections import OrderedDict
2223
from packaging.version import Version as V
2324

@@ -96,7 +97,7 @@ def _parse_header(self):
9697
else:
9798
self.file_format = 'header-attached'
9899

99-
self._global_info, self._ordered_channels, data_dtype, header_size, self._block_size = read_rhd(
100+
self._global_info, self._ordered_channels, data_dtype, header_size, self._block_size, channel_number_dict = read_rhd(
100101
self.filename, self.file_format
101102
)
102103

@@ -109,9 +110,19 @@ def _parse_header(self):
109110
for stream_index, (stream_index_key, stream_datatype) in enumerate(data_dtype.items()):
110111
# for 'one-file-per-signal' we have one memory map / neo stream
111112
if self.file_format == "one-file-per-signal":
112-
self._raw_data[stream_index] = np.memmap(
113-
raw_file_paths_dict[stream_index_key], dtype=stream_datatype, mode="r"
114-
)
113+
n_chans = channel_number_dict[stream_index_key]
114+
if stream_index_key == 4 or stream_index_key == 5:
115+
n_samples = int(os.path.getsize(raw_file_paths_dict[stream_index_key]) / 2) # uint16 2 bytes
116+
else:
117+
n_samples = int(os.path.getsize(raw_file_paths_dict[stream_index_key]) / (n_chans * 2))# unit16 2 bytes
118+
if stream_index_key != 6:
119+
self._raw_data[stream_index] = np.memmap(
120+
raw_file_paths_dict[stream_index_key], dtype=[stream_datatype[0]], shape = (n_chans, n_samples), mode="r"
121+
).T
122+
else:
123+
self._raw_data[stream_index] = np.memmap(
124+
raw_file_paths_dict[stream_index_key], dtype=[stream_datatype[0]], mode="r"
125+
)
115126
# for one-file-per-channel we have one memory map / channel stored as a list / neo stream
116127
else:
117128
self._raw_data[stream_index] = []
@@ -174,14 +185,14 @@ def _parse_header(self):
174185
elif self.file_format == 'one-file-per-signal':
175186
self._max_sigs_length = max(
176187
[
177-
raw_data.size * self._block_size
188+
raw_data.size
178189
for raw_data in self._raw_data.values()
179190
]
180191
)
181192
else:
182193
self._max_sigs_length = max(
183194
[
184-
len(raw_data) * raw_data[0].size * self._block_size
195+
len(raw_data) * raw_data[0].size
185196
for raw_data in self._raw_data.values()
186197
]
187198
)
@@ -220,7 +231,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
220231
if self.file_format == "header-attached":
221232
size = self._raw_data[chan_name0].size
222233
elif self.file_format == 'one-file-per-signal':
223-
size = self._raw_data[stream_index][chan_name0].size
234+
size = self._raw_data[stream_index][:,0].size
224235
else:
225236
size = self._raw_data[stream_index][0][chan_name0].size
226237
return size
@@ -249,7 +260,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
249260
if self.file_format == 'header-attached':
250261
shape = self._raw_data[channel_names[0]].shape
251262
elif self.file_format == 'one-file-per-signal':
252-
shape = self._raw_data[stream_index][channel_names[0]].shape
263+
shape = self._raw_data[stream_index][:, 0].shape
253264
else:
254265
if channel_indexes_are_none:
255266
shape = self._raw_data[stream_index][0][channel_names[0]].shape
@@ -272,7 +283,10 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
272283
if self.file_format == 'header-attached':
273284
data_chan = self._raw_data[chan_name]
274285
elif self.file_format == 'one-file-per-signal':
275-
data_chan = self._raw_data[stream_index][chan_name]
286+
if channel_indexes_are_none:
287+
data_chan = self._raw_data[stream_index][:, i]
288+
else:
289+
data_chan = self._raw_data[stream_index][:, channel_indexes[i]]
276290
else:
277291
if channel_indexes_are_none:
278292
data_chan = self._raw_data[stream_index][i][chan_name]
@@ -590,6 +604,8 @@ def read_rhd(filename, file_format: str):
590604
if bool(chan_info["channel_enabled"]):
591605
channels_by_type[chan_info["signal_type"]].append(chan_info)
592606

607+
channel_number_dict = {i: len(channels_by_type[i]) for i in range(6)}
608+
593609
header_size = f.tell()
594610

595611
sr = global_info["sampling_rate"]
@@ -606,25 +622,32 @@ def read_rhd(filename, file_format: str):
606622
if file_format == "header-attached":
607623
data_dtype = [("timestamp", "int32", BLOCK_SIZE)]
608624
else:
609-
data_dtype[6] = [("timestamp", "int32", BLOCK_SIZE)]
625+
data_dtype[6] = [("timestamp", "int32",)]
626+
channel_number_dict[6] = 1
610627
else:
611628
if file_format == "header-attached":
612629
data_dtype = [("timestamp", "uint32", BLOCK_SIZE)]
613630
else:
614-
data_dtype[6] = [("timestamp", "uint32", BLOCK_SIZE)]
631+
data_dtype[6] = [("timestamp", "uint32",)]
632+
channel_number_dict[6] = 1
615633

616634
# 0: RHD2000 amplifier channel
617635
for chan_info in channels_by_type[0]:
618636
name = chan_info["custom_channel_name"]
619637
chan_info["sampling_rate"] = sr
620638
chan_info["units"] = "uV"
621639
chan_info["gain"] = 0.195
622-
chan_info["offset"] = -32768 * 0.195
640+
if file_format == "header-attached":
641+
chan_info["offset"] = -32768 * 0.195
642+
else:
643+
chan_info["offset"] = 0.0
623644
ordered_channels.append(chan_info)
624645
if file_format == "header-attached":
625646
data_dtype += [(name, "uint16", BLOCK_SIZE)]
647+
elif file_format == 'one-file-per-signal':
648+
data_dtype[0] = "int16"
626649
else:
627-
data_dtype[0] += [(name, "uint16", BLOCK_SIZE)]
650+
data_dtype[0] += [(name, "int16")]
628651

629652
# 1: RHD2000 auxiliary input channel
630653
for chan_info in channels_by_type[1]:
@@ -636,8 +659,11 @@ def read_rhd(filename, file_format: str):
636659
ordered_channels.append(chan_info)
637660
if file_format == "header-attached":
638661
data_dtype += [(name, "uint16", BLOCK_SIZE // 4)]
662+
elif file_format == "one-file-per-signal":
663+
data_dtype[1] = "uint16"
639664
else:
640-
data_dtype[1] += [(name, "uint16", BLOCK_SIZE // 4)]
665+
data_dtype[1] += [(name, "uint16")]
666+
641667

642668
# 2: RHD2000 supply voltage channel
643669
for chan_info in channels_by_type[2]:
@@ -649,8 +675,10 @@ def read_rhd(filename, file_format: str):
649675
ordered_channels.append(chan_info)
650676
if file_format == "header-attached":
651677
data_dtype += [(name, "uint16")]
678+
elif file_format == "one-file-per-signal":
679+
data_dtype[1] = "uint16"
652680
else:
653-
data_dtype[1] += [(name, "uint16", BLOCK_SIZE // 4)]
681+
data_dtype[1] += [(name, "uint16")]
654682

655683
# temperature is not an official channel in the header
656684
for i in range(global_info["num_temp_sensor_channels"]):
@@ -680,8 +708,10 @@ def read_rhd(filename, file_format: str):
680708
ordered_channels.append(chan_info)
681709
if file_format == "header-attached":
682710
data_dtype += [(name, "uint16", BLOCK_SIZE)]
711+
elif file_format == 'one-file-per-signal':
712+
data_dtype[3] = "uint16"
683713
else:
684-
data_dtype[3] += [(name, "uint16", BLOCK_SIZE)]
714+
data_dtype[3] += [(name, "uint16")]
685715

686716
# 4: USB board digital input channel
687717
# 5: USB board digital output channel
@@ -699,8 +729,10 @@ def read_rhd(filename, file_format: str):
699729
ordered_channels.append(chan_info)
700730
if file_format == "header-attached":
701731
data_dtype += [(name, "uint16", BLOCK_SIZE)]
732+
elif file_format == "one-file-per-signal":
733+
data_dtype[sig_type] = "uint16"
702734
else:
703-
data_dtype[sig_type] += [(name, "uint16", BLOCK_SIZE)]
735+
data_dtype[sig_type] += [(name, "uint16",)]
704736

705737
if bool(global_info["notch_filter_mode"]) and version >= V("3.0"):
706738
global_info["notch_filter_applied"] = True
@@ -710,8 +742,9 @@ def read_rhd(filename, file_format: str):
710742
if not file_format == "header-attached":
711743
# filter out dtypes without any values
712744
data_dtype = {k:v for (k,v) in data_dtype.items() if len(v) > 0}
745+
channel_number_dict = {k:v for (k,v) in channel_number_dict.items() if v > 0}
713746

714-
return global_info, ordered_channels, data_dtype, header_size, BLOCK_SIZE
747+
return global_info, ordered_channels, data_dtype, header_size, BLOCK_SIZE, channel_number_dict
715748

716749

717750

0 commit comments

Comments
 (0)