@@ -58,10 +58,16 @@ def _parse_header(self):
5858 self .smrx_file = sonpy .lib .SonFile (sName = str (self .filename ), bReadOnly = True )
5959 smrx = self .smrx_file
6060
61+ self ._time_base = smrx .GetTimeBase ()
62+
6163 channel_infos = []
6264 signal_channels = []
65+ spike_channels = []
66+ self ._all_spike_ticks = {}
67+
6368 for chan_ind in range (smrx .MaxChannels ()):
6469 chan_type = smrx .ChannelType (chan_ind )
70+ chan_id = str (chan_ind )
6571 if chan_type == sonpy .lib .DataType .Adc :
6672 physical_chan = smrx .PhysicalChannel (chan_ind )
6773 divide = smrx .ChannelDivide (chan_ind )
@@ -78,13 +84,35 @@ def _parse_header(self):
7884 offset = smrx .GetChannelOffset (chan_ind )
7985 units = smrx .GetChannelUnits (chan_ind )
8086 ch_name = smrx .GetChannelTitle (chan_ind )
81- chan_id = str ( chan_ind )
87+
8288 dtype = 'int16'
8389 # set later after grouping
8490 stream_id = '0'
8591 signal_channels .append ((ch_name , chan_id , sr , dtype ,
8692 units , gain , offset , stream_id ))
8793
94+ elif chan_type == sonpy .lib .DataType .AdcMark :
95+ # spike and waveforms : only spike times is used here
96+ ch_name = smrx .GetChannelTitle (chan_ind )
97+ first_time = smrx .FirstTime (chan_ind , 0 , max_time )
98+ max_time = smrx .ChannelMaxTime (chan_ind )
99+ divide = smrx .ChannelDivide (chan_ind )
100+ # here we don't use filter (sonpy.lib.MarkerFilter()) so we get all marker
101+ wave_marks = smrx .ReadWaveMarks (chan_ind , int (max_time / divide ), 0 , max_time )
102+
103+ # here we load in memory all spike once because the access is really slow
104+ # with the ReadWaveMarks
105+ spike_ticks = np .array ([t .Tick for t in wave_marks ])
106+ spike_codes = np .array ([t .Code1 for t in wave_marks ])
107+
108+ unit_ids = np .unique (spike_codes )
109+ for unit_id in unit_ids :
110+ name = f'{ ch_name } #{ unit_id } '
111+ spike_chan_id = f'ch{ chan_id } #{ unit_id } '
112+ spike_channels .append ((name , spike_chan_id , '' , 1 , 0 , 0 , 0 ))
113+ mask = spike_codes == unit_id
114+ self ._all_spike_ticks [spike_chan_id ] = spike_ticks [mask ]
115+
88116 signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
89117
90118 # channels are grouped into stream if they have a common start, stop, size, divide and sampling_rate
@@ -104,8 +132,7 @@ def _parse_header(self):
104132 signal_streams = np .array (signal_streams , dtype = _signal_stream_dtype )
105133
106134 # spike channels not handled
107- spike_channels = []
108- spike_channels = np .array ([], dtype = _spike_channel_dtype )
135+ spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
109136
110137 # event channels not handled
111138 event_channels = []
@@ -115,9 +142,10 @@ def _parse_header(self):
115142 self ._seg_t_stop = - np .inf
116143 for info in self .stream_info :
117144 self ._seg_t_start = min (self ._seg_t_start ,
118- info ['first_time' ] / info ['sampling_rate' ])
145+ info ['first_time' ] * self ._time_base )
146+
119147 self ._seg_t_stop = max (self ._seg_t_stop ,
120- info ['max_time' ] / info [ 'sampling_rate' ] )
148+ info ['max_time' ] * self . _time_base )
121149
122150 self .header = {}
123151 self .header ['nb_block' ] = 1
@@ -141,7 +169,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
141169
142170 def _get_signal_t_start (self , block_index , seg_index , stream_index ):
143171 info = self .stream_info [stream_index ]
144- t_start = info ['first_time' ] / info [ 'sampling_rate' ]
172+ t_start = info ['first_time' ] * self . _time_base
145173 return t_start
146174
147175 def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
@@ -175,3 +203,28 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
175203 sigs [:, i ] = sig
176204
177205 return sigs
206+
207+ def _spike_count (self , block_index , seg_index , unit_index ):
208+ unit_id = self .header ['spike_channels' ][unit_index ]['id' ]
209+ spike_ticks = self ._all_spike_ticks [unit_id ]
210+ return spike_ticks .size
211+
212+ def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
213+ unit_id = self .header ['spike_channels' ][unit_index ]['id' ]
214+ spike_ticks = self ._all_spike_ticks [unit_id ]
215+ if t_start is not None :
216+ tick_start = int (t_start / self ._time_base )
217+ spike_ticks = spike_ticks [spike_ticks >= tick_start ]
218+ if t_stop is not None :
219+ tick_stop = int (t_stop / self ._time_base )
220+ spike_ticks = spike_ticks [spike_ticks <= tick_stop ]
221+ return spike_ticks
222+
223+ def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
224+ spike_times = spike_timestamps .astype (dtype )
225+ spike_times *= self ._time_base
226+ return spike_times
227+
228+ def _get_spike_raw_waveforms (self , block_index , seg_index ,
229+ spike_channel_index , t_start , t_stop ):
230+ return None
0 commit comments