Skip to content

Commit dbe2e95

Browse files
authored
Merge pull request #1589 from samuelgarcia/micromed_segments
Micromed segments
2 parents 7f6e973 + 6e56ba9 commit dbe2e95

File tree

2 files changed

+117
-38
lines changed

2 files changed

+117
-38
lines changed

neo/rawio/micromedrawio.py

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(self, filename=""):
5252

5353
def _parse_header(self):
5454

55-
self._buffer_descriptions = {0: {0: {}}}
5655

5756
with open(self.filename, "rb") as fid:
5857
f = StructFile(fid)
@@ -67,6 +66,7 @@ def _parse_header(self):
6766
rec_datetime = datetime.datetime(year + 1900, month, day, hour, minute, sec)
6867

6968
Data_Start_Offset, Num_Chan, Multiplexer, Rate_Min, Bytes = f.read_f("IHHHH", offset=138)
69+
sig_dtype = "u" + str(Bytes)
7070

7171
# header version
7272
(header_version,) = f.read_f("b", offset=175)
@@ -99,25 +99,37 @@ def _parse_header(self):
9999
if zname != zname2.decode("ascii").strip(" "):
100100
raise NeoReadWriteError("expected the zone name to match")
101101

102-
# raw signals memmap
103-
sig_dtype = "u" + str(Bytes)
104-
signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset)
105-
buffer_id = "0"
106-
stream_id = "0"
107-
self._buffer_descriptions[0][0][buffer_id] = {
108-
"type": "raw",
109-
"file_path": str(self.filename),
110-
"dtype": sig_dtype,
111-
"order": "C",
112-
"file_offset": 0,
113-
"shape": signal_shape,
114-
}
102+
103+
# "TRONCA" zone define segments
104+
zname2, pos, length = zones["TRONCA"]
105+
f.seek(pos)
106+
# this number avoids an infinite loop in case of a corrupted TRONCA zone (seg_start!=0 and trace_offset!=0)
107+
max_segments = 100
108+
self.info_segments = []
109+
for i in range(max_segments):
110+
# 4 bytes u4 each
111+
seg_start = int(np.frombuffer(f.read(4), dtype="u4")[0])
112+
trace_offset = int(np.frombuffer(f.read(4), dtype="u4")[0])
113+
if seg_start == 0 and trace_offset == 0:
114+
break
115+
else:
116+
self.info_segments.append((seg_start, trace_offset))
117+
118+
if len(self.info_segments) == 0:
119+
# one unique segment = general case
120+
self.info_segments.append((0, 0))
121+
122+
nb_segment = len(self.info_segments)
115123

116124
# Reading Code Info
117125
zname2, pos, length = zones["ORDER"]
118126
f.seek(pos)
119127
code = np.frombuffer(f.read(Num_Chan * 2), dtype="u2")
120128

129+
# unique stream and buffer
130+
buffer_id = "0"
131+
stream_id = "0"
132+
121133
units_code = {-1: "nV", 0: "uV", 1: "mV", 2: 1, 100: "percent", 101: "dimensionless", 102: "dimensionless"}
122134
signal_channels = []
123135
sig_grounds = []
@@ -140,10 +152,8 @@ def _parse_header(self):
140152
(sampling_rate,) = f.read_f("H")
141153
sampling_rate *= Rate_Min
142154
chan_id = str(c)
155+
signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))
143156

144-
signal_channels.append(
145-
(chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
146-
)
147157

148158
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
149159

@@ -155,6 +165,32 @@ def _parse_header(self):
155165
raise NeoReadWriteError("The sampling rates must be the same across signal channels")
156166
self._sampling_rate = float(np.unique(signal_channels["sampling_rate"])[0])
157167

168+
# memmap traces buffer
169+
full_signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset)
170+
seg_limits = [trace_offset for seg_start, trace_offset in self.info_segments] + [full_signal_shape[0]]
171+
self._t_starts = []
172+
self._buffer_descriptions = {0 :{}}
173+
for seg_index in range(nb_segment):
174+
seg_start, trace_offset = self.info_segments[seg_index]
175+
self._t_starts.append(seg_start / self._sampling_rate)
176+
177+
start = seg_limits[seg_index]
178+
stop = seg_limits[seg_index + 1]
179+
180+
shape = (stop - start, Num_Chan)
181+
file_offset = Data_Start_Offset + ( start * np.dtype(sig_dtype).itemsize * Num_Chan)
182+
self._buffer_descriptions[0][seg_index] = {}
183+
self._buffer_descriptions[0][seg_index][buffer_id] = {
184+
"type" : "raw",
185+
"file_path" : str(self.filename),
186+
"dtype" : sig_dtype,
187+
"order": "C",
188+
"file_offset" : file_offset,
189+
"shape" : shape,
190+
}
191+
192+
193+
158194
# Event channels
159195
event_channels = []
160196
event_channels.append(("Trigger", "", "event"))
@@ -176,13 +212,18 @@ def _parse_header(self):
176212
dtype = np.dtype(ev_dtype)
177213
rawevent = np.memmap(self.filename, dtype=dtype, mode="r", offset=pos, shape=length // dtype.itemsize)
178214

179-
keep = (
180-
(rawevent["start"] >= rawevent["start"][0])
181-
& (rawevent["start"] < signal_shape[0])
182-
& (rawevent["start"] != 0)
183-
)
184-
rawevent = rawevent[keep]
185-
self._raw_events.append(rawevent)
215+
# important: all event timings are relative to the first segment t_start
216+
self._raw_events.append([])
217+
for seg_index in range(nb_segment):
218+
left_lim = seg_limits[seg_index]
219+
right_lim = seg_limits[seg_index + 1]
220+
keep = (
221+
(rawevent["start"] >= left_lim)
222+
& (rawevent["start"] < right_lim)
223+
& (rawevent["start"] != 0)
224+
)
225+
self._raw_events[-1].append(rawevent[keep])
226+
186227

187228
# No spikes
188229
spike_channels = []
@@ -191,7 +232,7 @@ def _parse_header(self):
191232
# fill into the header dict
192233
self.header = {}
193234
self.header["nb_block"] = 1
194-
self.header["nb_segment"] = [1]
235+
self.header["nb_segment"] = [nb_segment]
195236
self.header["signal_buffers"] = signal_buffers
196237
self.header["signal_streams"] = signal_streams
197238
self.header["signal_channels"] = signal_channels
@@ -216,38 +257,40 @@ def _source_name(self):
216257
return self.filename
217258

218259
def _segment_t_start(self, block_index, seg_index):
219-
return 0.0
260+
return self._t_starts[seg_index]
220261

221262
def _segment_t_stop(self, block_index, seg_index):
222-
sig_size = self.get_signal_size(block_index, seg_index, 0)
223-
t_stop = sig_size / self._sampling_rate
224-
return t_stop
263+
duration = self.get_signal_size(block_index, seg_index, stream_index=0) / self._sampling_rate
264+
return duration + self.segment_t_start(block_index, seg_index)
225265

226266
def _get_signal_t_start(self, block_index, seg_index, stream_index):
227-
if stream_index != 0:
228-
raise ValueError("`stream_index` must be 0")
229-
return 0.0
267+
assert stream_index == 0
268+
return self._t_starts[seg_index]
230269

231270
def _spike_count(self, block_index, seg_index, unit_index):
232271
return 0
233272

234273
def _event_count(self, block_index, seg_index, event_channel_index):
235-
n = self._raw_events[event_channel_index].size
274+
n = self._raw_events[event_channel_index][seg_index].size
236275
return n
237276

238277
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
239278

240-
raw_event = self._raw_events[event_channel_index]
279+
raw_event = self._raw_events[event_channel_index][seg_index]
280+
281+
# important: all event timings are relative to the first segment t_start
282+
seg_start0, _ = self.info_segments[0]
241283

242284
if t_start is not None:
243-
keep = raw_event["start"] >= int(t_start * self._sampling_rate)
285+
keep = raw_event["start"] + seg_start0 >= int(t_start * self._sampling_rate)
244286
raw_event = raw_event[keep]
245287

246288
if t_stop is not None:
247-
keep = raw_event["start"] <= int(t_stop * self._sampling_rate)
289+
keep = raw_event["start"] + seg_start0 <= int(t_stop * self._sampling_rate)
248290
raw_event = raw_event[keep]
249291

250-
timestamp = raw_event["start"]
292+
timestamp = raw_event["start"] + seg_start0
293+
251294
if event_channel_index < 2:
252295
durations = None
253296
else:

neo/test/rawiotest/test_micromedrawio.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,50 @@
88

99
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
1010

11+
import numpy as np
1112

1213
class TestMicromedRawIO(
1314
BaseTestRawIO,
1415
unittest.TestCase,
1516
):
1617
rawioclass = MicromedRawIO
1718
entities_to_download = ["micromed"]
18-
entities_to_test = ["micromed/File_micromed_1.TRC"]
19+
entities_to_test = [
20+
"micromed/File_micromed_1.TRC",
21+
"micromed/File_mircomed2.TRC",
22+
"micromed/File_mircomed2_2segments.TRC",
23+
]
24+
25+
def test_micromed_multi_segments(self):
26+
file_full = self.get_local_path("micromed/File_mircomed2.TRC")
27+
file_splitted = self.get_local_path("micromed/File_mircomed2_2segments.TRC")
28+
29+
# the second file contains 2 pieces of the first file
30+
# so it is 2 segments with the same traces but reduced
31+
# note that traces in the split file can differ at the very end of the cut
32+
33+
reader1 = MicromedRawIO(file_full)
34+
reader1.parse_header()
35+
assert reader1.segment_count(block_index=0) == 1
36+
assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0.
37+
traces1 = reader1.get_analogsignal_chunk(stream_index=0)
38+
39+
reader2 = MicromedRawIO(file_splitted)
40+
reader2.parse_header()
41+
print(reader2)
42+
assert reader2.segment_count(block_index=0) == 2
43+
44+
# check that the pieces of the second file are equal to the first file (except for a truncation at the end)
45+
for seg_index in range(2):
46+
t_start = reader2.get_signal_t_start(block_index=0, seg_index=seg_index, stream_index=0)
47+
assert t_start > 0
48+
sr = reader2.get_signal_sampling_rate(stream_index=0)
49+
ind_start = int(t_start * sr)
50+
traces2 = reader2.get_analogsignal_chunk(block_index=0, seg_index=seg_index, stream_index=0)
51+
traces1_chunk = traces1[ind_start: ind_start+traces2.shape[0]]
52+
# we remove the last 100 samples because the tool that cuts traces truncates the last buffer
53+
assert np.array_equal(traces2[:-100], traces1_chunk[:-100])
54+
1955

2056

2157
if __name__ == "__main__":

0 commit comments

Comments
 (0)