Skip to content

Commit f6d85db

Browse files
committed
Implement spikes reading
1 parent 719fc3d commit f6d85db

File tree

2 files changed

+77
-32
lines changed

2 files changed

+77
-32
lines changed

neo/rawio/alphaomegarawio.py

Lines changed: 75 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,8 @@ def _read_file_blocks(self, filename, prune_channels=True):
361361
assert channel_number not in channel_type
362362
channel_type[channel_number] = "segmented_analog"
363363
(
364-
pre_trigm_sec,
365-
post_trigm_sec,
364+
pre_trig_ms,
365+
post_trig_ms,
366366
level_value,
367367
trg_mode,
368368
yes_rms,
@@ -377,8 +377,8 @@ def _read_file_blocks(self, filename, prune_channels=True):
377377
"sample_rate": sample_rate * 1000,
378378
"spike_count": spike_count,
379379
"mode_spike": mode_spike,
380-
"pre_trigm_sec": pre_trigm_sec,
381-
"post_trigm_sec": post_trigm_sec,
380+
"pre_trig_duration": pre_trig_ms / 1000,
381+
"post_trig_duration": post_trig_ms / 1000,
382382
"level_value": level_value,
383383
"trg_mode": trg_mode,
384384
"automatic_level_base_rms": yes_rms,
@@ -772,24 +772,23 @@ def _parse_header(self):
772772
signal_channels.sort(key=lambda x: (x[7], x[0]))
773773
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
774774

775-
# TODO: read the waveforms then uncomment the following
776-
# spike_channels = set(
777-
# (
778-
# c["name"],
779-
# i,
780-
# "uV",
781-
# c["gain"] / c["bit_resolution"],
782-
# 0,
783-
# round(c["pre_trigm_sec"] * c["sample_rate"]),
784-
# c["sample_rate"],
785-
# ) for block in self._blocks
786-
# for segment in block
787-
# for i, c in segment["spikes"].items()
788-
# )
789-
# spike_channels = list(spike_channels)
790-
# spike_channels.sort(key=lambda x: x[0])
791-
# spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
792-
spike_channels = np.array([], dtype=_spike_channel_dtype)
775+
spike_channels = set(
776+
(
777+
c["name"],
778+
i,
779+
"uV",
780+
c["gain"] / c["bit_resolution"],
781+
0,
782+
round(c["pre_trig_duration"] * c["sample_rate"]),
783+
c["sample_rate"],
784+
)
785+
for block in self._blocks
786+
for segment in block
787+
for i, c in segment["spikes"].items()
788+
)
789+
spike_channels = list(spike_channels)
790+
spike_channels.sort(key=lambda x: x[0])
791+
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
793792

794793
event_channels = set(
795794
(event["name"], i, "event")
@@ -994,20 +993,65 @@ def _get_analogsignal_chunk(
994993
return sigs[i_start - min_size : i_stop - min_size, :]
995994

996995
def _spike_count(self, block_index, seg_index, spike_channel_index):
997-
pass
996+
spike_id = int(self.header["spike_channels"]["id"][spike_channel_index])
997+
nb_spikes = sum(
998+
len(f) for f in self._blocks[block_index][seg_index]["spikes"][spike_id]["positions"].values()
999+
)
1000+
return nb_spikes
9981001

9991002
def _get_spike_timestamps(
10001003
self, block_index, seg_index, spike_channel_index, t_start, t_stop
10011004
):
1002-
pass
1005+
if self._spike_count(block_index, seg_index, spike_channel_index):
1006+
spike_id = int(self.header["spike_channels"]["id"][spike_channel_index])
1007+
spikes = self._blocks[block_index][seg_index]["spikes"][spike_id]
1008+
if t_start is None:
1009+
t_start = self._segment_t_start(block_index, seg_index)
1010+
if t_stop is None:
1011+
t_stop = self._segment_t_stop(block_index, seg_index)
1012+
effective_start = t_start * spikes["sample_rate"]
1013+
effective_stop = t_stop * spikes["sample_rate"]
1014+
timestamps = np.array([p[0] for f in spikes["positions"].values() for p in f if effective_start <= p[0] <= effective_stop])
1015+
else:
1016+
timestamps = np.array([], dtype=np.uint32)
1017+
return timestamps
10031018

10041019
def _rescale_spike_timestamp(self, spike_timestamps, dtype):
1005-
pass
1020+
# let's hope every spike channels have the same sampling rate
1021+
sample_rate = int(self.header["spike_channels"]["wf_sampling_rate"][0])
1022+
spike_timestamps = spike_timestamps.astype(dtype) / sample_rate
1023+
return spike_timestamps
10061024

10071025
def _get_spike_raw_waveforms(
10081026
self, block_index, seg_index, spike_channel_index, t_start, t_stop
10091027
):
1010-
pass
1028+
spike_id = int(self.header["spike_channels"]["id"][spike_channel_index])
1029+
# nb_spikes = self._spike_count(block_index, seg_index, spike_channel_index)
1030+
nb_spikes = self._get_spike_timestamps(block_index, seg_index, spike_channel_index, t_start, t_stop).size
1031+
spikes = self._blocks[block_index][seg_index]["spikes"][spike_id]
1032+
spike_length = {p[2] for f in spikes["positions"].values() for p in f}
1033+
assert len(spike_length) == 1
1034+
spike_length = spike_length.pop()
1035+
waveforms = np.ndarray((nb_spikes, spike_length), dtype=np.short)
1036+
if t_start is None:
1037+
t_start = self._segment_t_start(block_index, seg_index)
1038+
if t_stop is None:
1039+
t_stop = self._segment_t_stop(block_index, seg_index)
1040+
effective_start = t_start * spikes["sample_rate"]
1041+
effective_stop = t_stop * spikes["sample_rate"]
1042+
i = 0
1043+
for filename in spikes["positions"]:
1044+
for timestamp, file_position, length in spikes["positions"][filename]:
1045+
if effective_start <= timestamp <= effective_stop:
1046+
waveforms[i, :length] = np.frombuffer(
1047+
self._opened_files[filename]["mmap"],
1048+
dtype=np.short,
1049+
count=length,
1050+
offset=file_position,
1051+
)
1052+
i += 1
1053+
waveforms.shape = nb_spikes, 1, spike_length
1054+
return waveforms
10111055

10121056
def _event_count(self, block_index, seg_index, event_channel_index):
10131057
event_id = int(self.header["event_channels"]["id"][event_channel_index])
@@ -1151,10 +1195,11 @@ def get_name(f, name_length):
11511195
SDefLevelAnalog = struct.Struct("<ffhhhh")
11521196
"""
11531197
Then if mode is Level of Segmented:
1154-
- pre_trigm_sec (float): number of seconds before segment trigger
1155-
- post_trigm_sec (float): number of seconds after segment trigger
1156-
- level_value (short): not sure…
1157-
- trg_mode (short): not sure…
1198+
- pre_trig_msec (float): number of milliseconds before segment trigger
1199+
- post_trig_msec (float): number of milliseconds after segment trigger
1200+
- level_value (short): unknown (should be the level that trigger a
1201+
spike detection)
1202+
- trg_mode (short): unknown (level or template mode?)
11581203
- yes_rms (short): 1 if automatic level calculation base on RMS
11591204
- total_gain_100 (short): see above
11601205
- name (n-char string): channel name; n=length-48

neo/test/rawiotest/test_alphaomegarawio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def test_read_file_blocks(self):
110110
self.assertIn("sample_rate", channel)
111111
self.assertIn("spike_count", channel)
112112
self.assertIn("mode_spike", channel)
113-
self.assertIn("pre_trigm_sec", channel)
114-
self.assertIn("post_trigm_sec", channel)
113+
self.assertIn("pre_trig_duration", channel)
114+
self.assertIn("post_trig_duration", channel)
115115
self.assertIn("level_value", channel)
116116
self.assertIn("trg_mode", channel)
117117
self.assertIn("automatic_level_base_rms", channel)

0 commit comments

Comments
 (0)