Skip to content

Commit 0b95b37

Browse files
committed
Add read spike times in CedIO
1 parent 33729b5 commit 0b95b37

File tree

1 file changed

+63
-6
lines changed

1 file changed

+63
-6
lines changed

neo/rawio/cedrawio.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,19 @@ def _parse_header(self):
5757

5858
self.smrx_file = sonpy.lib.SonFile(sName=str(self.filename), bReadOnly=True)
5959
smrx = self.smrx_file
60+
61+
self._time_base = smrx.GetTimeBase()
6062

6163
channel_infos = []
6264
signal_channels = []
65+
spike_channels = []
66+
self._all_spike_ticks = {}
67+
6368
for chan_ind in range(smrx.MaxChannels()):
6469
chan_type = smrx.ChannelType(chan_ind)
70+
chan_id = str(chan_ind)
71+
#~ print(chan_type)
72+
#~ continue
6573
if chan_type == sonpy.lib.DataType.Adc:
6674
physical_chan = smrx.PhysicalChannel(chan_ind)
6775
divide = smrx.ChannelDivide(chan_ind)
@@ -78,14 +86,37 @@ def _parse_header(self):
7886
offset = smrx.GetChannelOffset(chan_ind)
7987
units = smrx.GetChannelUnits(chan_ind)
8088
ch_name = smrx.GetChannelTitle(chan_ind)
81-
chan_id = str(chan_ind)
89+
8290
dtype = 'int16'
8391
# set later after grouping
8492
stream_id = '0'
8593
signal_channels.append((ch_name, chan_id, sr, dtype,
8694
units, gain, offset, stream_id))
8795

96+
elif chan_type == sonpy.lib.DataType.AdcMark:
97+
# spike and waveforms : only spike times is used here
98+
ch_name = smrx.GetChannelTitle(chan_ind)
99+
first_time = smrx.FirstTime(chan_ind, 0, max_time)
100+
max_time = smrx.ChannelMaxTime(chan_ind)
101+
divide = smrx.ChannelDivide(chan_ind)
102+
# here we don't use filter (sonpy.lib.MarkerFilter()) so we get all marker
103+
wave_marks = smrx.ReadWaveMarks(chan_ind, int(max_time/divide), 0, max_time)
104+
105+
# here we load in memory all spike once for all because the access is really slow
106+
# with the ReadWaveMarks
107+
spike_ticks = np.array([t.Tick for t in wave_marks])
108+
spike_codes = np.array([t.Code1 for t in wave_marks])
109+
110+
unit_ids = np.unique(spike_codes)
111+
for unit_id in unit_ids:
112+
name = f'{ch_name}#{unit_id}'
113+
spike_chan_id = f'ch{chan_id}#{unit_id}'
114+
spike_channels.append((name, spike_chan_id, '', 1, 0, 0, 0))
115+
mask = spike_codes == unit_id
116+
self._all_spike_ticks[spike_chan_id] = spike_ticks[mask]
117+
88118
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
119+
89120

90121
# channels are grouped into stream if they have a common start, stop, size, divide and sampling_rate
91122
channel_infos = np.array(channel_infos,
@@ -104,8 +135,7 @@ def _parse_header(self):
104135
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
105136

106137
# spike channels not handled
107-
spike_channels = []
108-
spike_channels = np.array([], dtype=_spike_channel_dtype)
138+
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
109139

110140
# event channels not handled
111141
event_channels = []
@@ -115,9 +145,12 @@ def _parse_header(self):
115145
self._seg_t_stop = -np.inf
116146
for info in self.stream_info:
117147
self._seg_t_start = min(self._seg_t_start,
118-
info['first_time'] / info['sampling_rate'])
148+
#~ info['first_time'] / info['sampling_rate'])
149+
info['first_time'] * self._time_base)
150+
119151
self._seg_t_stop = max(self._seg_t_stop,
120-
info['max_time'] / info['sampling_rate'])
152+
#~ info['max_time'] / info['sampling_rate'])
153+
info['max_time'] * self._time_base)
121154

122155
self.header = {}
123156
self.header['nb_block'] = 1
@@ -141,7 +174,8 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
141174

142175
def _get_signal_t_start(self, block_index, seg_index, stream_index):
143176
info = self.stream_info[stream_index]
144-
t_start = info['first_time'] / info['sampling_rate']
177+
#~ t_start = info['first_time'] / info['sampling_rate']
178+
t_start = info['first_time'] * self._time_base
145179
return t_start
146180

147181
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
@@ -175,3 +209,26 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
175209
sigs[:, i] = sig
176210

177211
return sigs
212+
213+
def _spike_count(self, block_index, seg_index, unit_index):
214+
unit_id = self.header['spike_channels'][unit_index]['id']
215+
spike_ticks = self._all_spike_ticks[unit_id]
216+
return spike_ticks.size
217+
218+
219+
def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
220+
unit_id = self.header['spike_channels'][unit_index]['id']
221+
spike_ticks = self._all_spike_ticks[unit_id]
222+
if t_start is not None:
223+
tick_start = int(t_start / self._time_base)
224+
spike_ticks = spike_ticks[spike_ticks >= tick_start]
225+
if t_stop is not None:
226+
tick_stop = int(t_stop / self._time_base)
227+
spike_ticks = spike_ticks[spike_ticks <= tick_stop]
228+
return spike_ticks
229+
230+
def _rescale_spike_timestamp(self, spike_timestamps, dtype):
231+
spike_times = spike_timestamps.astype(dtype)
232+
spike_times *= self._time_base
233+
return spike_times
234+

0 commit comments

Comments
 (0)