2929from ephyspy .features .base import SweepFeature
3030from ephyspy .features .utils import (
3131 FeatureError ,
32+ during_stimulus_only ,
3233 fetch_available_fts ,
3334 get_sweep_burst_metrics ,
3435 has_rebound ,
3839 median_idx ,
3940 sag_idxs ,
4041 sag_period ,
42+ where_spike_during_stimulus ,
4143 where_stimulus ,
4244)
4345from ephyspy .utils import (
@@ -169,7 +171,6 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
169171 def _compute (self , recompute = False , store_diagnostics = True ):
170172 stim_end = float ("nan" )
171173 if has_stimulus (self .data ):
172- stim_type = stimulus_type (self .data )
173174 where_stim = where_stimulus (self .data )
174175 stim_end = self .data .t [where_stim ][- 1 ]
175176 i_end = self .data .i [where_stim ][- 1 ]
@@ -199,10 +200,7 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
199200 super ().__init__ (data , compute_at_init , ** kwargs )
200201
201202 def _compute (self , recompute = False , store_diagnostics = True ):
202- peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
203- onset = self .lookup_sweep_feature ("stim_onset" )
204- end = self .lookup_sweep_feature ("stim_end" )
205- stim_window = where_between (peak_t , onset , end )
203+ stim_window = where_spike_during_stimulus (self , recompute = recompute )
206204
207205 peak_i = self .lookup_spike_feature ("peak_index" )[stim_window ]
208206 num_ap = len (peak_i )
@@ -211,7 +209,7 @@ def _compute(self, recompute=False, store_diagnostics=True):
211209 num_ap = float ("nan" )
212210
213211 if store_diagnostics :
214- peak_t = peak_t [stim_window ]
212+ peak_t = self . lookup_spike_feature ( " peak_t" ) [stim_window ]
215213 peak_v = self .lookup_spike_feature ("peak_v" )[stim_window ]
216214 self ._update_diagnostics (
217215 {
@@ -270,11 +268,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
270268 def _compute (self , recompute = False , store_diagnostics = True ):
271269 ap_latency = float ("nan" )
272270 if has_stimulus (self .data ):
273- onset = self .lookup_sweep_feature ("stim_onset" , recompute = recompute )
274- end = self .lookup_sweep_feature ("stim_end" , recompute = recompute )
275- thresh_t = self .lookup_spike_feature ("threshold_t" , recompute = recompute )
276271 thresholds = self .lookup_spike_feature ("threshold_v" , recompute = recompute )
277- stim_window = where_between (thresh_t , onset , end )
272+ thresh_t = self .lookup_spike_feature ("threshold_t" , recompute = recompute )
273+ onset = self .lookup_sweep_feature ("stim_onset" , recompute = recompute )
274+ stim_window = where_spike_during_stimulus (self , recompute = recompute )
278275
279276 thresh_t_stim = thresh_t [stim_window ]
280277
@@ -287,7 +284,6 @@ def _compute(self, recompute=False, store_diagnostics=True):
287284 self ._update_diagnostics (
288285 {
289286 "onset" : onset ,
290- "end" : end ,
291287 "spike_times_during_stim" : thresh_t_stim ,
292288 "t_first" : t_first_spike ,
293289 "v_first" : v_first_spike ,
@@ -488,16 +484,14 @@ def _compute(self, recompute=False, store_diagnostics=True):
488484 ):
489485 onset = self .lookup_sweep_feature ("stim_onset" , recompute = recompute )
490486 end = self .lookup_sweep_feature ("stim_end" , recompute = recompute )
487+ peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
491488 t_half = (end - onset ) / 2 + onset
492489 where_1st_half = where_between (self .data .t , onset , t_half )
493490 where_2nd_half = where_between (self .data .t , t_half , end )
494491 t_1st_half = self .data .t [where_1st_half ]
495492 t_2nd_half = self .data .t [where_2nd_half ]
496493
497- peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
498- onset = self .lookup_sweep_feature ("stim_onset" , recompute = recompute )
499- end = self .lookup_sweep_feature ("stim_end" , recompute = recompute )
500- stim_window = where_between (peak_t , onset , end )
494+ stim_window = where_spike_during_stimulus (self , recompute = recompute )
501495 peak_t = peak_t [stim_window ]
502496
503497 spikes_1st_half = peak_t [peak_t < t_half ]
@@ -542,11 +536,9 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
542536
543537 def _compute (self , recompute = False , store_diagnostics = True ):
544538 ap_amp_slope = float ("nan" )
545- onset = self .lookup_sweep_feature ("stim_onset" )
546- end = self .lookup_sweep_feature ("stim_end" )
547539 peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
548540 peak_v = self .lookup_spike_feature ("peak_v" , recompute = recompute )
549- stim_window = where_between ( peak_t , onset , end )
541+ stim_window = where_spike_during_stimulus ( self , recompute = recompute )
550542
551543 peak_t = peak_t [stim_window ]
552544 peak_v = peak_v [stim_window ]
@@ -1380,25 +1372,26 @@ def _compute(self, recompute=False, store_diagnostics=True):
13801372 num_bursts = float ("nan" )
13811373 num_ap = self .lookup_sweep_feature ("num_ap" , recompute = recompute )
13821374 if num_ap > 5 and has_stimulus (self .data ):
1383- idx_burst , idx_burst_start , idx_burst_end = get_sweep_burst_metrics (
1384- self .data
1385- )
1386- peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
1387- if not np .isnan (idx_burst ).any ():
1388- t_burst_start = peak_t [idx_burst_start ]
1389- t_burst_end = peak_t [idx_burst_end ]
1390- num_bursts = len (idx_burst )
1391- num_bursts = float ("nan" ) if num_bursts == 0 else num_bursts
1392- if store_diagnostics :
1393- self ._update_diagnostics (
1394- {
1395- "idx_burst" : idx_burst ,
1396- "idx_burst_start" : idx_burst_start ,
1397- "idx_burst_end" : idx_burst_end ,
1398- "t_burst_start" : t_burst_start ,
1399- "t_burst_end" : t_burst_end ,
1400- }
1401- )
1375+ with during_stimulus_only (self .data ) as sweep :
1376+ idx_burst , idx_burst_start , idx_burst_end = get_sweep_burst_metrics (
1377+ sweep
1378+ )
1379+ peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
1380+ if not np .isnan (idx_burst ).any ():
1381+ t_burst_start = peak_t [idx_burst_start ]
1382+ t_burst_end = peak_t [idx_burst_end ]
1383+ num_bursts = len (idx_burst )
1384+ num_bursts = float ("nan" ) if num_bursts == 0 else num_bursts
1385+ if store_diagnostics :
1386+ self ._update_diagnostics (
1387+ {
1388+ "idx_burst" : idx_burst ,
1389+ "idx_burst_start" : idx_burst_start ,
1390+ "idx_burst_end" : idx_burst_end ,
1391+ "t_burst_start" : t_burst_start ,
1392+ "t_burst_end" : t_burst_end ,
1393+ }
1394+ )
14021395 return num_bursts
14031396
14041397 def _plot (self , ax : Optional [Axes ] = None , ** kwargs ) -> Axes :
@@ -1430,29 +1423,30 @@ def _compute(self, recompute=False, store_diagnostics=True):
14301423 max_burstiness = float ("nan" )
14311424 num_ap = self .lookup_sweep_feature ("num_ap" , recompute = recompute )
14321425 if num_ap > 5 and has_stimulus (self .data ):
1433- idx_burst , idx_burst_start , idx_burst_end = get_sweep_burst_metrics (
1434- self .data
1435- )
1436- peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
1437- if not np .isnan (idx_burst ).any ():
1438- t_burst_start = peak_t [idx_burst_start ]
1439- t_burst_end = peak_t [idx_burst_end ]
1440- num_bursts = len (idx_burst )
1441- max_burstiness = idx_burst .max () if num_bursts > 0 else float ("nan" )
1442- max_burstiness = (
1443- float ("nan" ) if max_burstiness < 0 else max_burstiness
1444- ) # don't consider negative burstiness
1426+ with during_stimulus_only (self .data ) as sweep :
1427+ idx_burst , idx_burst_start , idx_burst_end = get_sweep_burst_metrics (
1428+ sweep
1429+ )
1430+ peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
1431+ if not np .isnan (idx_burst ).any ():
1432+ t_burst_start = peak_t [idx_burst_start ]
1433+ t_burst_end = peak_t [idx_burst_end ]
1434+ num_bursts = len (idx_burst )
1435+ max_burstiness = idx_burst .max () if num_bursts > 0 else float ("nan" )
1436+ max_burstiness = (
1437+ float ("nan" ) if max_burstiness < 0 else max_burstiness
1438+ ) # don't consider negative burstiness
14451439
1446- if store_diagnostics :
1447- self ._update_diagnostics (
1448- {
1449- "idx_burst" : idx_burst ,
1450- "idx_burst_start" : idx_burst_start ,
1451- "idx_burst_end" : idx_burst_end ,
1452- "t_burst_start" : t_burst_start ,
1453- "t_burst_end" : t_burst_end ,
1454- }
1455- )
1440+ if store_diagnostics :
1441+ self ._update_diagnostics (
1442+ {
1443+ "idx_burst" : idx_burst ,
1444+ "idx_burst_start" : idx_burst_start ,
1445+ "idx_burst_end" : idx_burst_end ,
1446+ "t_burst_start" : t_burst_start ,
1447+ "t_burst_end" : t_burst_end ,
1448+ }
1449+ )
14561450 return max_burstiness
14571451
14581452 def _plot (self , ax : Optional [Axes ] = None , ** kwargs ) -> Axes :
@@ -1510,7 +1504,7 @@ class Sweep_ISI_adapt(SweepFeature):
15101504 """Extract sweep level inter-spike-interval (ISI) adaptation index feature.
15111505
15121506 depends on: ISIs.
1513- description: / .
1507+ description: 1st ISI / 2nd ISI .
15141508 units: /."""
15151509
15161510 def __init__ (self , data = None , compute_at_init = True , ** kwargs ):
@@ -1519,7 +1513,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
15191513 def _compute (self , recompute = False , store_diagnostics = True ):
15201514 isi_adapt = float ("nan" )
15211515 if has_spikes (self .data ):
1522- isi = self .lookup_spike_feature ("isi" , recompute = recompute )[1 :]
1516+ isi = self .lookup_spike_feature ("isi" , recompute = recompute )
1517+ during_stim = where_spike_during_stimulus (self , recompute = recompute )
1518+ isi = isi [during_stim ]
1519+ isi = isi [1 :] if isi [0 ] == 0 else isi
15231520 if len (isi ) > 1 :
15241521 isi_adapt = isi [1 ] / isi [0 ]
15251522
@@ -1538,7 +1535,7 @@ class Sweep_ISI_adapt_avg(SweepFeature):
15381535 """Extract sweep level average inter-spike-interval (ISI) adaptation index feature.
15391536
15401537 depends on: ISIs.
1541- description: / .
1538+ description: mean of ISI_{i} / ISI_{i+1} .
15421539 units: /."""
15431540
15441541 def __init__ (self , data = None , compute_at_init = True , ** kwargs ):
@@ -1547,7 +1544,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
15471544 def _compute (self , recompute = False , store_diagnostics = True ):
15481545 isi_adapt_avg = float ("nan" )
15491546 if has_spikes (self .data ):
1550- isi = self .lookup_spike_feature ("isi" , recompute = recompute )[1 :]
1547+ isi = self .lookup_spike_feature ("isi" , recompute = recompute )
1548+ during_stim = where_spike_during_stimulus (self , recompute = recompute )
1549+ isi = isi [during_stim ]
1550+ isi = isi [1 :] if isi [0 ] == 0 else isi
15511551 if len (isi ) > 2 :
15521552 isi_changes = isi [1 :] / isi [:- 1 ]
15531553 isi_adapt_avg = isi_changes .mean ()
@@ -1567,7 +1567,7 @@ class Sweep_AP_amp_adapt(SweepFeature):
15671567 """Extract sweep level AP amplitude adaptation index feature.
15681568
15691569 depends on: ap_amp.
1570- description: / .
1570+ description: 1st AP_amp / 2nd AP_amp .
15711571 units: mV/s."""
15721572
15731573 def __init__ (self , data = None , compute_at_init = True , ** kwargs ):
@@ -1577,6 +1577,8 @@ def _compute(self, recompute=False, store_diagnostics=True):
15771577 ap_amp_adapt = float ("nan" )
15781578 if has_spikes (self .data ):
15791579 ap_amp = self .lookup_spike_feature ("ap_amp" , recompute = recompute )
1580+ during_stim = where_spike_during_stimulus (self , recompute = recompute )
1581+ ap_amp = ap_amp [during_stim ]
15801582 if len (ap_amp ) > 1 :
15811583 ap_amp_adapt = ap_amp [1 ] / ap_amp [0 ]
15821584
@@ -1596,7 +1598,7 @@ class Sweep_AP_amp_adapt_avg(SweepFeature):
15961598 """Extract sweep level average AP amplitude adaptation index feature.
15971599
15981600 depends on: ap_amp.
1599- description: / .
1601+ description: mean of AP_amp_{i} / AP_amp_{i+1} .
16001602 units: /."""
16011603
16021604 def __init__ (self , data = None , compute_at_init = True , ** kwargs ):
@@ -1606,6 +1608,8 @@ def _compute(self, recompute=False, store_diagnostics=True):
16061608 ap_amp_adapt_avg = float ("nan" )
16071609 if has_spikes (self .data ):
16081610 ap_amp = self .lookup_spike_feature ("ap_amp" , recompute = recompute )
1611+ during_stim = where_spike_during_stimulus (self , recompute = recompute )
1612+ ap_amp = ap_amp [during_stim ]
16091613 if len (ap_amp ) > 2 :
16101614 ap_amp_changes = ap_amp [1 :] / ap_amp [:- 1 ]
16111615 ap_amp_adapt_avg = ap_amp_changes .mean ()
@@ -1635,12 +1639,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
16351639 def _compute (self , recompute = False , store_diagnostics = True ):
16361640 num_wild_spikes = float ("nan" )
16371641 if has_spikes (self .data ):
1638- onset = self .lookup_sweep_feature ("stim_onset" , recompute = recompute )
1639- end = self .lookup_sweep_feature ("stim_end" , recompute = recompute )
16401642 peak_t = self .lookup_spike_feature ("peak_t" , recompute = recompute )
16411643 peak_idx = self .lookup_spike_feature ("peak_index" , recompute = recompute )
16421644 peak_v = self .lookup_spike_feature ("peak_v" , recompute = recompute )
1643- stim_window = where_between ( peak_t , onset , end )
1645+ stim_window = where_spike_during_stimulus ( self , recompute = recompute )
16441646
16451647 idx_wild_spikes = peak_idx [~ stim_window ]
16461648 t_wild_spikes = peak_t [~ stim_window ]
@@ -1703,10 +1705,7 @@ def _select(self, data):
17031705 during stimulus window.
17041706 """
17051707 if self .ap_selector is None :
1706- peak_t = self .lookup_spike_feature ("peak_t" )
1707- onset = self .lookup_sweep_feature ("stim_onset" )
1708- end = self .lookup_sweep_feature ("stim_end" )
1709- stim_window = where_between (peak_t , onset , end )
1708+ stim_window = where_spike_during_stimulus (self )
17101709
17111710 # include sanity check?
17121711 # first = np.array([], dtype=int)
@@ -1950,11 +1949,7 @@ def _select(self, data):
19501949 # ISI_{i} = t_{i} - t_{i-1}
19511950 # therefore ISI_{1} = t_1 - t_0 = t_1 - NaN -> defined as 0
19521951 # first actual ISI is from 1st to 2nd spike, hence the +1
1953- isi = self .lookup_spike_feature ("isi" )
1954- threshold_t = self .lookup_spike_feature ("threshold_t" )
1955- onset = self .lookup_sweep_feature ("stim_onset" )
1956- end = self .lookup_sweep_feature ("stim_end" )
1957- stim_window = where_between (threshold_t , onset , end )
1952+ stim_window = where_spike_during_stimulus (self )
19581953 if np .sum (stim_window ) > 1 :
19591954 return super ()._select (data ) + 1
19601955 return np .array ([], dtype = int )
0 commit comments