@@ -361,8 +361,8 @@ def _read_file_blocks(self, filename, prune_channels=True):
361361 assert channel_number not in channel_type
362362 channel_type [channel_number ] = "segmented_analog"
363363 (
364- pre_trigm_sec ,
365- post_trigm_sec ,
364+ pre_trig_ms ,
365+ post_trig_ms ,
366366 level_value ,
367367 trg_mode ,
368368 yes_rms ,
@@ -377,8 +377,8 @@ def _read_file_blocks(self, filename, prune_channels=True):
377377 "sample_rate" : sample_rate * 1000 ,
378378 "spike_count" : spike_count ,
379379 "mode_spike" : mode_spike ,
380- "pre_trigm_sec " : pre_trigm_sec ,
381- "post_trigm_sec " : post_trigm_sec ,
380+ "pre_trig_duration " : pre_trig_ms / 1000 ,
381+ "post_trig_duration " : post_trig_ms / 1000 ,
382382 "level_value" : level_value ,
383383 "trg_mode" : trg_mode ,
384384 "automatic_level_base_rms" : yes_rms ,
@@ -772,24 +772,23 @@ def _parse_header(self):
772772 signal_channels .sort (key = lambda x : (x [7 ], x [0 ]))
773773 signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
774774
775- # TODO: read the waveforms then uncomment the following
776- # spike_channels = set(
777- # (
778- # c["name"],
779- # i,
780- # "uV",
781- # c["gain"] / c["bit_resolution"],
782- # 0,
783- # round(c["pre_trigm_sec"] * c["sample_rate"]),
784- # c["sample_rate"],
785- # ) for block in self._blocks
786- # for segment in block
787- # for i, c in segment["spikes"].items()
788- # )
789- # spike_channels = list(spike_channels)
790- # spike_channels.sort(key=lambda x: x[0])
791- # spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
792- spike_channels = np .array ([], dtype = _spike_channel_dtype )
775+ spike_channels = set (
776+ (
777+ c ["name" ],
778+ i ,
779+ "uV" ,
780+ c ["gain" ] / c ["bit_resolution" ],
781+ 0 ,
782+ round (c ["pre_trig_duration" ] * c ["sample_rate" ]),
783+ c ["sample_rate" ],
784+ )
785+ for block in self ._blocks
786+ for segment in block
787+ for i , c in segment ["spikes" ].items ()
788+ )
789+ spike_channels = list (spike_channels )
790+ spike_channels .sort (key = lambda x : x [0 ])
791+ spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
793792
794793 event_channels = set (
795794 (event ["name" ], i , "event" )
@@ -994,20 +993,65 @@ def _get_analogsignal_chunk(
994993 return sigs [i_start - min_size : i_stop - min_size , :]
995994
996995 def _spike_count (self , block_index , seg_index , spike_channel_index ):
997- pass
996+ spike_id = int (self .header ["spike_channels" ]["id" ][spike_channel_index ])
997+ nb_spikes = sum (
998+ len (f ) for f in self ._blocks [block_index ][seg_index ]["spikes" ][spike_id ]["positions" ].values ()
999+ )
1000+ return nb_spikes
9981001
9991002 def _get_spike_timestamps (
10001003 self , block_index , seg_index , spike_channel_index , t_start , t_stop
10011004 ):
1002- pass
1005+ if self ._spike_count (block_index , seg_index , spike_channel_index ):
1006+ spike_id = int (self .header ["spike_channels" ]["id" ][spike_channel_index ])
1007+ spikes = self ._blocks [block_index ][seg_index ]["spikes" ][spike_id ]
1008+ if t_start is None :
1009+ t_start = self ._segment_t_start (block_index , seg_index )
1010+ if t_stop is None :
1011+ t_stop = self ._segment_t_stop (block_index , seg_index )
1012+ effective_start = t_start * spikes ["sample_rate" ]
1013+ effective_stop = t_stop * spikes ["sample_rate" ]
1014+ timestamps = np .array ([p [0 ] for f in spikes ["positions" ].values () for p in f if effective_start <= p [0 ] <= effective_stop ])
1015+ else :
1016+ timestamps = np .array ([], dtype = np .uint32 )
1017+ return timestamps
10031018
10041019 def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
1005- pass
1020+ # let's hope every spike channels have the same sampling rate
1021+ sample_rate = int (self .header ["spike_channels" ]["wf_sampling_rate" ][0 ])
1022+ spike_timestamps = spike_timestamps .astype (dtype ) / sample_rate
1023+ return spike_timestamps
10061024
10071025 def _get_spike_raw_waveforms (
10081026 self , block_index , seg_index , spike_channel_index , t_start , t_stop
10091027 ):
1010- pass
1028+ spike_id = int (self .header ["spike_channels" ]["id" ][spike_channel_index ])
1029+ # nb_spikes = self._spike_count(block_index, seg_index, spike_channel_index)
1030+ nb_spikes = self ._get_spike_timestamps (block_index , seg_index , spike_channel_index , t_start , t_stop ).size
1031+ spikes = self ._blocks [block_index ][seg_index ]["spikes" ][spike_id ]
1032+ spike_length = {p [2 ] for f in spikes ["positions" ].values () for p in f }
1033+ assert len (spike_length ) == 1
1034+ spike_length = spike_length .pop ()
1035+ waveforms = np .ndarray ((nb_spikes , spike_length ), dtype = np .short )
1036+ if t_start is None :
1037+ t_start = self ._segment_t_start (block_index , seg_index )
1038+ if t_stop is None :
1039+ t_stop = self ._segment_t_stop (block_index , seg_index )
1040+ effective_start = t_start * spikes ["sample_rate" ]
1041+ effective_stop = t_stop * spikes ["sample_rate" ]
1042+ i = 0
1043+ for filename in spikes ["positions" ]:
1044+ for timestamp , file_position , length in spikes ["positions" ][filename ]:
1045+ if effective_start <= timestamp <= effective_stop :
1046+ waveforms [i , :length ] = np .frombuffer (
1047+ self ._opened_files [filename ]["mmap" ],
1048+ dtype = np .short ,
1049+ count = length ,
1050+ offset = file_position ,
1051+ )
1052+ i += 1
1053+ waveforms .shape = nb_spikes , 1 , spike_length
1054+ return waveforms
10111055
10121056 def _event_count (self , block_index , seg_index , event_channel_index ):
10131057 event_id = int (self .header ["event_channels" ]["id" ][event_channel_index ])
@@ -1151,10 +1195,11 @@ def get_name(f, name_length):
11511195SDefLevelAnalog = struct .Struct ("<ffhhhh" )
11521196"""
11531197 Then if mode is Level of Segmented:
1154- - pre_trigm_sec (float): number of seconds before segment trigger
1155- - post_trigm_sec (float): number of seconds after segment trigger
1156- - level_value (short): not sure…
1157- - trg_mode (short): not sure…
1198+ - pre_trig_msec (float): number of milliseconds before segment trigger
1199+ - post_trig_msec (float): number of milliseconds after segment trigger
1200+ - level_value (short): unknown (should be the level that trigger a
1201+ spike detection)
1202+ - trg_mode (short): unknown (level or template mode?)
11581203 - yes_rms (short): 1 if automatic level calculation base on RMS
11591204 - total_gain_100 (short): see above
11601205 - name (n-char string): channel name; n=length-48
0 commit comments