@@ -57,11 +57,19 @@ def _parse_header(self):
5757
5858 self .smrx_file = sonpy .lib .SonFile (sName = str (self .filename ), bReadOnly = True )
5959 smrx = self .smrx_file
60+
61+ self ._time_base = smrx .GetTimeBase ()
6062
6163 channel_infos = []
6264 signal_channels = []
65+ spike_channels = []
66+ self ._all_spike_ticks = {}
67+
6368 for chan_ind in range (smrx .MaxChannels ()):
6469 chan_type = smrx .ChannelType (chan_ind )
70+ chan_id = str (chan_ind )
71+ #~ print(chan_type)
72+ #~ continue
6573 if chan_type == sonpy .lib .DataType .Adc :
6674 physical_chan = smrx .PhysicalChannel (chan_ind )
6775 divide = smrx .ChannelDivide (chan_ind )
@@ -78,14 +86,37 @@ def _parse_header(self):
7886 offset = smrx .GetChannelOffset (chan_ind )
7987 units = smrx .GetChannelUnits (chan_ind )
8088 ch_name = smrx .GetChannelTitle (chan_ind )
81- chan_id = str ( chan_ind )
89+
8290 dtype = 'int16'
8391 # set later after grouping
8492 stream_id = '0'
8593 signal_channels .append ((ch_name , chan_id , sr , dtype ,
8694 units , gain , offset , stream_id ))
8795
96+ elif chan_type == sonpy .lib .DataType .AdcMark :
97+ # spike and waveforms : only spike times is used here
98+ ch_name = smrx .GetChannelTitle (chan_ind )
99+ first_time = smrx .FirstTime (chan_ind , 0 , max_time )
100+ max_time = smrx .ChannelMaxTime (chan_ind )
101+ divide = smrx .ChannelDivide (chan_ind )
102+ # here we don't use filter (sonpy.lib.MarkerFilter()) so we get all marker
103+ wave_marks = smrx .ReadWaveMarks (chan_ind , int (max_time / divide ), 0 , max_time )
104+
105+ # here we load in memory all spike once for all because the access is really slow
106+ # with the ReadWaveMarks
107+ spike_ticks = np .array ([t .Tick for t in wave_marks ])
108+ spike_codes = np .array ([t .Code1 for t in wave_marks ])
109+
110+ unit_ids = np .unique (spike_codes )
111+ for unit_id in unit_ids :
112+ name = f'{ ch_name } #{ unit_id } '
113+ spike_chan_id = f'ch{ chan_id } #{ unit_id } '
114+ spike_channels .append ((name , spike_chan_id , '' , 1 , 0 , 0 , 0 ))
115+ mask = spike_codes == unit_id
116+ self ._all_spike_ticks [spike_chan_id ] = spike_ticks [mask ]
117+
88118 signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
119+
89120
90121 # channels are grouped into stream if they have a common start, stop, size, divide and sampling_rate
91122 channel_infos = np .array (channel_infos ,
@@ -104,8 +135,7 @@ def _parse_header(self):
104135 signal_streams = np .array (signal_streams , dtype = _signal_stream_dtype )
105136
106137 # spike channels not handled
107- spike_channels = []
108- spike_channels = np .array ([], dtype = _spike_channel_dtype )
138+ spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
109139
110140 # event channels not handled
111141 event_channels = []
@@ -115,9 +145,12 @@ def _parse_header(self):
115145 self ._seg_t_stop = - np .inf
116146 for info in self .stream_info :
117147 self ._seg_t_start = min (self ._seg_t_start ,
118- info ['first_time' ] / info ['sampling_rate' ])
148+ #~ info['first_time'] / info['sampling_rate'])
149+ info ['first_time' ] * self ._time_base )
150+
119151 self ._seg_t_stop = max (self ._seg_t_stop ,
120- info ['max_time' ] / info ['sampling_rate' ])
152+ #~ info['max_time'] / info['sampling_rate'])
153+ info ['max_time' ] * self ._time_base )
121154
122155 self .header = {}
123156 self .header ['nb_block' ] = 1
@@ -141,7 +174,8 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
141174
142175 def _get_signal_t_start (self , block_index , seg_index , stream_index ):
143176 info = self .stream_info [stream_index ]
144- t_start = info ['first_time' ] / info ['sampling_rate' ]
177+ #~ t_start = info['first_time'] / info['sampling_rate']
178+ t_start = info ['first_time' ] * self ._time_base
145179 return t_start
146180
147181 def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
@@ -175,3 +209,26 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
175209 sigs [:, i ] = sig
176210
177211 return sigs
212+
213+ def _spike_count (self , block_index , seg_index , unit_index ):
214+ unit_id = self .header ['spike_channels' ][unit_index ]['id' ]
215+ spike_ticks = self ._all_spike_ticks [unit_id ]
216+ return spike_ticks .size
217+
218+
219+ def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
220+ unit_id = self .header ['spike_channels' ][unit_index ]['id' ]
221+ spike_ticks = self ._all_spike_ticks [unit_id ]
222+ if t_start is not None :
223+ tick_start = int (t_start / self ._time_base )
224+ spike_ticks = spike_ticks [spike_ticks >= tick_start ]
225+ if t_stop is not None :
226+ tick_stop = int (t_stop / self ._time_base )
227+ spike_ticks = spike_ticks [spike_ticks <= tick_stop ]
228+ return spike_ticks
229+
230+ def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
231+ spike_times = spike_timestamps .astype (dtype )
232+ spike_times *= self ._time_base
233+ return spike_times
234+
0 commit comments