2020
2121import matplotlib .pyplot as plt
2222import numpy as np
23+ import pandas as pd
2324from matplotlib .pyplot import Axes
2425from numpy import ndarray
2526from pandas import DataFrame
@@ -58,20 +59,25 @@ def __init__(
5859 start : Optional [float ] = None ,
5960 end : Optional [float ] = None ,
6061 metadata : Dict = {},
62+ careful_spike_processing : bool = False ,
6163 ** kwargs ,
6264 ):
6365 """
6466 Args:
6567 metadata (dict, optional): Metadata for the sweep. Defaults to None.
6668 The metadata can be used to set hyperparameters for features or
6769 store identifying information, such as cell id etc..
70+ careful_spike_processing (bool, optional): Whether to perform spike
71+ processing carefully, i.e. detect pre-, post- and during-stimulus
72+ spikes seperately. Typically leads to less errors, but can be slower.
6873 *args: Additional arguments for EphysSweepFeatureExtractor.
6974 **kwargs: Additional keyword arguments for EphysSweepFeatureExtractor.
7075 """
7176 super ().__init__ (t = t , v = v , i = i , start = start , end = end , ** kwargs )
7277 self .metadata = metadata
7378 self .added_spike_features = {}
7479 self .features = {}
80+ self .careful_spike_processing = careful_spike_processing
7581 self ._init_sweep ()
7682
7783 def _init_sweep (self ):
@@ -86,8 +92,8 @@ def _init_sweep(self):
8692 self .t = self .t [:idx_end ]
8793 self .v = self .v [:idx_end ]
8894 self .i = self .i [:idx_end ]
89- self .start = self .t [0 ]
90- self .end = self .t [- 1 ]
95+ self .start = self .t [0 ] if self . start is None else self . start
96+ self .end = self .t [- 1 ] if self . end is None else self . end
9197
9298 def add_spike_feature (self , feature_name : str , feature_func : Callable ):
9399 """Add a new spike feature to the extractor.
@@ -130,9 +136,55 @@ def _process_added_spike_features(self):
130136 def process_spikes (self ):
131137 """Perform spike-related feature analysis, which includes added spike
132138 features not part of the original AllenSDK implementation."""
133- self ._process_individual_spikes ()
134- self ._process_spike_related_features ()
135- self ._process_added_spike_features ()
139+
140+ def run_spike_processing ():
141+ self ._process_individual_spikes ()
142+ self ._process_spike_related_features ()
143+ self ._process_added_spike_features ()
144+ if not self ._spikes_df .empty :
145+ self ._spikes_df ["T_start" ] = self .start
146+ self ._spikes_df ["T_end" ] = self .end
147+
148+ where_stimulus = self .i != 0
149+ if np .any (where_stimulus ):
150+ stim_onset , stim_end = self .t [where_stimulus ][[0 , - 1 ]]
151+ same_t = lambda t1 , t2 , tol = 1e-3 : (
152+ abs (t1 - t2 ) < tol if t1 is not None and t2 is not None else False
153+ )
154+ else :
155+ stim_onset , stim_end = None , None
156+ same_t = lambda t1 , t2 , tol = 1e-3 : False
157+
158+ if (
159+ same_t (stim_onset , self .start )
160+ and same_t (stim_end , self .end )
161+ or not self .careful_spike_processing
162+ ):
163+ run_spike_processing ()
164+ else :
165+ t_intervals = [self .t [0 ], stim_onset , stim_end , self .t [- 1 ]]
166+ spike_dfs = []
167+ orig_interval = (self .start , self .end )
168+ for t_start , t_end in zip (t_intervals [:- 1 ], t_intervals [1 :]):
169+ self .start = t_start
170+ self .end = t_end
171+ run_spike_processing ()
172+ spike_dfs .append (self ._spikes_df )
173+ del self ._spikes_df
174+
175+ self .start , self .end = orig_interval
176+ self ._spikes_df = pd .concat (spike_dfs )
177+
178+ # remove duplicate spikes at interval boundaries
179+ if not self ._spikes_df .empty :
180+ T_lockout = 1e-3
181+ boundary_idxs = np .where (self ._spikes_df ["T_start" ].diff () > 0 )[0 ]
182+ for idx in boundary_idxs :
183+ ap1_t , ap2_t = self ._spikes_df .iloc [[idx - 1 , idx ]]["threshold_t" ]
184+ if ap1_t + T_lockout > ap2_t :
185+ rm_idx = idx if stim_onset < ap1_t < stim_end else idx - 1
186+ self ._spikes_df .drop (rm_idx , inplace = True )
187+ self ._spikes_df .reset_index (drop = True , inplace = True )
136188
137189 def get_features (self , recompute : bool = False ) -> Dict [str , float ]:
138190 """Compute all features that have been added to the `EphysSweep` instance.
@@ -309,6 +361,7 @@ def __init__(
309361 start : Optional [Union [List , ndarray , float ]] = None ,
310362 end : Optional [Union [List , ndarray , float ]] = None ,
311363 metadata : Dict = {},
364+ careful_spike_processing : bool = False ,
312365 * args ,
313366 ** kwargs ,
314367 ):
@@ -320,6 +373,9 @@ def __init__(
320373 t_start (ndarray, optional): Start time for each sweep.
321374 t_end (ndarray, optional): End time for each sweep.
322375 metadata (dict, optional): Metadata for the sweep set.
376+ careful_spike_processing (bool, optional): Whether to perform spike
377+ processing carefully, i.e. detect pre-, post- and during-stimulus
378+ spikes seperately. Typically leads to less errors, but can be slower.
323379 *args: Additional arguments for EphysSweepSetFeatureExtractor.
324380 **kwargs: Additional keyword arguments for EphysSweepSetFeatureExtractor.
325381 """
@@ -337,6 +393,7 @@ def __init__(
337393 self .metadata = metadata
338394 for sweep in self .sweeps ():
339395 sweep .metadata = metadata
396+ sweep .careful_spike_processing = careful_spike_processing
340397 self .features = {}
341398
342399 @property
0 commit comments