Skip to content

Commit b610ef4

Browse files
committed
wip: refactor and small changes to ISI and other features to exclude spikes outisde of stimulus
1 parent 14a4655 commit b610ef4

File tree

4 files changed

+140
-91
lines changed

4 files changed

+140
-91
lines changed

ephyspy/features/sweep_features.py

Lines changed: 72 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ephyspy.features.base import SweepFeature
3030
from ephyspy.features.utils import (
3131
FeatureError,
32+
during_stimulus_only,
3233
fetch_available_fts,
3334
get_sweep_burst_metrics,
3435
has_rebound,
@@ -38,6 +39,7 @@
3839
median_idx,
3940
sag_idxs,
4041
sag_period,
42+
where_spike_during_stimulus,
4143
where_stimulus,
4244
)
4345
from ephyspy.utils import (
@@ -169,7 +171,6 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
169171
def _compute(self, recompute=False, store_diagnostics=True):
170172
stim_end = float("nan")
171173
if has_stimulus(self.data):
172-
stim_type = stimulus_type(self.data)
173174
where_stim = where_stimulus(self.data)
174175
stim_end = self.data.t[where_stim][-1]
175176
i_end = self.data.i[where_stim][-1]
@@ -199,10 +200,7 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
199200
super().__init__(data, compute_at_init, **kwargs)
200201

201202
def _compute(self, recompute=False, store_diagnostics=True):
202-
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
203-
onset = self.lookup_sweep_feature("stim_onset")
204-
end = self.lookup_sweep_feature("stim_end")
205-
stim_window = where_between(peak_t, onset, end)
203+
stim_window = where_spike_during_stimulus(self, recompute=recompute)
206204

207205
peak_i = self.lookup_spike_feature("peak_index")[stim_window]
208206
num_ap = len(peak_i)
@@ -211,7 +209,7 @@ def _compute(self, recompute=False, store_diagnostics=True):
211209
num_ap = float("nan")
212210

213211
if store_diagnostics:
214-
peak_t = peak_t[stim_window]
212+
peak_t = self.lookup_spike_feature("peak_t")[stim_window]
215213
peak_v = self.lookup_spike_feature("peak_v")[stim_window]
216214
self._update_diagnostics(
217215
{
@@ -270,11 +268,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
270268
def _compute(self, recompute=False, store_diagnostics=True):
271269
ap_latency = float("nan")
272270
if has_stimulus(self.data):
273-
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
274-
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
275-
thresh_t = self.lookup_spike_feature("threshold_t", recompute=recompute)
276271
thresholds = self.lookup_spike_feature("threshold_v", recompute=recompute)
277-
stim_window = where_between(thresh_t, onset, end)
272+
thresh_t = self.lookup_spike_feature("threshold_t", recompute=recompute)
273+
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
274+
stim_window = where_spike_during_stimulus(self, recompute=recompute)
278275

279276
thresh_t_stim = thresh_t[stim_window]
280277

@@ -287,7 +284,6 @@ def _compute(self, recompute=False, store_diagnostics=True):
287284
self._update_diagnostics(
288285
{
289286
"onset": onset,
290-
"end": end,
291287
"spike_times_during_stim": thresh_t_stim,
292288
"t_first": t_first_spike,
293289
"v_first": v_first_spike,
@@ -488,16 +484,14 @@ def _compute(self, recompute=False, store_diagnostics=True):
488484
):
489485
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
490486
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
487+
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
491488
t_half = (end - onset) / 2 + onset
492489
where_1st_half = where_between(self.data.t, onset, t_half)
493490
where_2nd_half = where_between(self.data.t, t_half, end)
494491
t_1st_half = self.data.t[where_1st_half]
495492
t_2nd_half = self.data.t[where_2nd_half]
496493

497-
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
498-
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
499-
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
500-
stim_window = where_between(peak_t, onset, end)
494+
stim_window = where_spike_during_stimulus(self, recompute=recompute)
501495
peak_t = peak_t[stim_window]
502496

503497
spikes_1st_half = peak_t[peak_t < t_half]
@@ -542,11 +536,9 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
542536

543537
def _compute(self, recompute=False, store_diagnostics=True):
544538
ap_amp_slope = float("nan")
545-
onset = self.lookup_sweep_feature("stim_onset")
546-
end = self.lookup_sweep_feature("stim_end")
547539
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
548540
peak_v = self.lookup_spike_feature("peak_v", recompute=recompute)
549-
stim_window = where_between(peak_t, onset, end)
541+
stim_window = where_spike_during_stimulus(self, recompute=recompute)
550542

551543
peak_t = peak_t[stim_window]
552544
peak_v = peak_v[stim_window]
@@ -1380,25 +1372,26 @@ def _compute(self, recompute=False, store_diagnostics=True):
13801372
num_bursts = float("nan")
13811373
num_ap = self.lookup_sweep_feature("num_ap", recompute=recompute)
13821374
if num_ap > 5 and has_stimulus(self.data):
1383-
idx_burst, idx_burst_start, idx_burst_end = get_sweep_burst_metrics(
1384-
self.data
1385-
)
1386-
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
1387-
if not np.isnan(idx_burst).any():
1388-
t_burst_start = peak_t[idx_burst_start]
1389-
t_burst_end = peak_t[idx_burst_end]
1390-
num_bursts = len(idx_burst)
1391-
num_bursts = float("nan") if num_bursts == 0 else num_bursts
1392-
if store_diagnostics:
1393-
self._update_diagnostics(
1394-
{
1395-
"idx_burst": idx_burst,
1396-
"idx_burst_start": idx_burst_start,
1397-
"idx_burst_end": idx_burst_end,
1398-
"t_burst_start": t_burst_start,
1399-
"t_burst_end": t_burst_end,
1400-
}
1401-
)
1375+
with during_stimulus_only(self.data) as sweep:
1376+
idx_burst, idx_burst_start, idx_burst_end = get_sweep_burst_metrics(
1377+
sweep
1378+
)
1379+
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
1380+
if not np.isnan(idx_burst).any():
1381+
t_burst_start = peak_t[idx_burst_start]
1382+
t_burst_end = peak_t[idx_burst_end]
1383+
num_bursts = len(idx_burst)
1384+
num_bursts = float("nan") if num_bursts == 0 else num_bursts
1385+
if store_diagnostics:
1386+
self._update_diagnostics(
1387+
{
1388+
"idx_burst": idx_burst,
1389+
"idx_burst_start": idx_burst_start,
1390+
"idx_burst_end": idx_burst_end,
1391+
"t_burst_start": t_burst_start,
1392+
"t_burst_end": t_burst_end,
1393+
}
1394+
)
14021395
return num_bursts
14031396

14041397
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
@@ -1430,29 +1423,30 @@ def _compute(self, recompute=False, store_diagnostics=True):
14301423
max_burstiness = float("nan")
14311424
num_ap = self.lookup_sweep_feature("num_ap", recompute=recompute)
14321425
if num_ap > 5 and has_stimulus(self.data):
1433-
idx_burst, idx_burst_start, idx_burst_end = get_sweep_burst_metrics(
1434-
self.data
1435-
)
1436-
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
1437-
if not np.isnan(idx_burst).any():
1438-
t_burst_start = peak_t[idx_burst_start]
1439-
t_burst_end = peak_t[idx_burst_end]
1440-
num_bursts = len(idx_burst)
1441-
max_burstiness = idx_burst.max() if num_bursts > 0 else float("nan")
1442-
max_burstiness = (
1443-
float("nan") if max_burstiness < 0 else max_burstiness
1444-
) # don't consider negative burstiness
1426+
with during_stimulus_only(self.data) as sweep:
1427+
idx_burst, idx_burst_start, idx_burst_end = get_sweep_burst_metrics(
1428+
sweep
1429+
)
1430+
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
1431+
if not np.isnan(idx_burst).any():
1432+
t_burst_start = peak_t[idx_burst_start]
1433+
t_burst_end = peak_t[idx_burst_end]
1434+
num_bursts = len(idx_burst)
1435+
max_burstiness = idx_burst.max() if num_bursts > 0 else float("nan")
1436+
max_burstiness = (
1437+
float("nan") if max_burstiness < 0 else max_burstiness
1438+
) # don't consider negative burstiness
14451439

1446-
if store_diagnostics:
1447-
self._update_diagnostics(
1448-
{
1449-
"idx_burst": idx_burst,
1450-
"idx_burst_start": idx_burst_start,
1451-
"idx_burst_end": idx_burst_end,
1452-
"t_burst_start": t_burst_start,
1453-
"t_burst_end": t_burst_end,
1454-
}
1455-
)
1440+
if store_diagnostics:
1441+
self._update_diagnostics(
1442+
{
1443+
"idx_burst": idx_burst,
1444+
"idx_burst_start": idx_burst_start,
1445+
"idx_burst_end": idx_burst_end,
1446+
"t_burst_start": t_burst_start,
1447+
"t_burst_end": t_burst_end,
1448+
}
1449+
)
14561450
return max_burstiness
14571451

14581452
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
@@ -1510,7 +1504,7 @@ class Sweep_ISI_adapt(SweepFeature):
15101504
"""Extract sweep level inter-spike-interval (ISI) adaptation index feature.
15111505
15121506
depends on: ISIs.
1513-
description: /.
1507+
description: 1st ISI / 2nd ISI.
15141508
units: /."""
15151509

15161510
def __init__(self, data=None, compute_at_init=True, **kwargs):
@@ -1519,7 +1513,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
15191513
def _compute(self, recompute=False, store_diagnostics=True):
15201514
isi_adapt = float("nan")
15211515
if has_spikes(self.data):
1522-
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
1516+
isi = self.lookup_spike_feature("isi", recompute=recompute)
1517+
during_stim = where_spike_during_stimulus(self, recompute=recompute)
1518+
isi = isi[during_stim]
1519+
isi = isi[1:] if isi[0] == 0 else isi
15231520
if len(isi) > 1:
15241521
isi_adapt = isi[1] / isi[0]
15251522

@@ -1538,7 +1535,7 @@ class Sweep_ISI_adapt_avg(SweepFeature):
15381535
"""Extract sweep level average inter-spike-interval (ISI) adaptation index feature.
15391536
15401537
depends on: ISIs.
1541-
description: /.
1538+
description: mean of ISI_{i} / ISI_{i+1}.
15421539
units: /."""
15431540

15441541
def __init__(self, data=None, compute_at_init=True, **kwargs):
@@ -1547,7 +1544,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
15471544
def _compute(self, recompute=False, store_diagnostics=True):
15481545
isi_adapt_avg = float("nan")
15491546
if has_spikes(self.data):
1550-
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
1547+
isi = self.lookup_spike_feature("isi", recompute=recompute)
1548+
during_stim = where_spike_during_stimulus(self, recompute=recompute)
1549+
isi = isi[during_stim]
1550+
isi = isi[1:] if isi[0] == 0 else isi
15511551
if len(isi) > 2:
15521552
isi_changes = isi[1:] / isi[:-1]
15531553
isi_adapt_avg = isi_changes.mean()
@@ -1567,7 +1567,7 @@ class Sweep_AP_amp_adapt(SweepFeature):
15671567
"""Extract sweep level AP amplitude adaptation index feature.
15681568
15691569
depends on: ap_amp.
1570-
description: /.
1570+
description: 1st AP_amp / 2nd AP_amp.
15711571
units: mV/s."""
15721572

15731573
def __init__(self, data=None, compute_at_init=True, **kwargs):
@@ -1577,6 +1577,8 @@ def _compute(self, recompute=False, store_diagnostics=True):
15771577
ap_amp_adapt = float("nan")
15781578
if has_spikes(self.data):
15791579
ap_amp = self.lookup_spike_feature("ap_amp", recompute=recompute)
1580+
during_stim = where_spike_during_stimulus(self, recompute=recompute)
1581+
ap_amp = ap_amp[during_stim]
15801582
if len(ap_amp) > 1:
15811583
ap_amp_adapt = ap_amp[1] / ap_amp[0]
15821584

@@ -1596,7 +1598,7 @@ class Sweep_AP_amp_adapt_avg(SweepFeature):
15961598
"""Extract sweep level average AP amplitude adaptation index feature.
15971599
15981600
depends on: ap_amp.
1599-
description: /.
1601+
description: mean of AP_amp_{i} / AP_amp_{i+1}.
16001602
units: /."""
16011603

16021604
def __init__(self, data=None, compute_at_init=True, **kwargs):
@@ -1606,6 +1608,8 @@ def _compute(self, recompute=False, store_diagnostics=True):
16061608
ap_amp_adapt_avg = float("nan")
16071609
if has_spikes(self.data):
16081610
ap_amp = self.lookup_spike_feature("ap_amp", recompute=recompute)
1611+
during_stim = where_spike_during_stimulus(self, recompute=recompute)
1612+
ap_amp = ap_amp[during_stim]
16091613
if len(ap_amp) > 2:
16101614
ap_amp_changes = ap_amp[1:] / ap_amp[:-1]
16111615
ap_amp_adapt_avg = ap_amp_changes.mean()
@@ -1635,12 +1639,10 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
16351639
def _compute(self, recompute=False, store_diagnostics=True):
16361640
num_wild_spikes = float("nan")
16371641
if has_spikes(self.data):
1638-
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
1639-
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
16401642
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
16411643
peak_idx = self.lookup_spike_feature("peak_index", recompute=recompute)
16421644
peak_v = self.lookup_spike_feature("peak_v", recompute=recompute)
1643-
stim_window = where_between(peak_t, onset, end)
1645+
stim_window = where_spike_during_stimulus(self, recompute=recompute)
16441646

16451647
idx_wild_spikes = peak_idx[~stim_window]
16461648
t_wild_spikes = peak_t[~stim_window]
@@ -1703,10 +1705,7 @@ def _select(self, data):
17031705
during stimulus window.
17041706
"""
17051707
if self.ap_selector is None:
1706-
peak_t = self.lookup_spike_feature("peak_t")
1707-
onset = self.lookup_sweep_feature("stim_onset")
1708-
end = self.lookup_sweep_feature("stim_end")
1709-
stim_window = where_between(peak_t, onset, end)
1708+
stim_window = where_spike_during_stimulus(self)
17101709

17111710
# include sanity check?
17121711
# first = np.array([], dtype=int)
@@ -1950,11 +1949,7 @@ def _select(self, data):
19501949
# ISI_{i} = t_{i} - t_{i-1}
19511950
# therefore ISI_{1} = t_1 - t_0 = t_1 - NaN -> defined as 0
19521951
# first actual ISI is from 1st to 2nd spike, hence the +1
1953-
isi = self.lookup_spike_feature("isi")
1954-
threshold_t = self.lookup_spike_feature("threshold_t")
1955-
onset = self.lookup_sweep_feature("stim_onset")
1956-
end = self.lookup_sweep_feature("stim_end")
1957-
stim_window = where_between(threshold_t, onset, end)
1952+
stim_window = where_spike_during_stimulus(self)
19581953
if np.sum(stim_window) > 1:
19591954
return super()._select(data) + 1
19601955
return np.array([], dtype=int)

ephyspy/features/sweepset_features.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333

3434

3535
from ephyspy.features.base import SweepSetFeature
36-
from ephyspy.features.utils import fetch_available_fts, median_idx
36+
from ephyspy.features.utils import (
37+
fetch_available_fts,
38+
median_idx,
39+
where_spike_during_stimulus,
40+
)
3741
from ephyspy.utils import is_sweepset_feature
3842

3943

@@ -482,12 +486,9 @@ def _compute(self, recompute=False, store_diagnostics=False):
482486
sweep_idx = np.where(has_ap)[0][0]
483487

484488
spike_df = self.data[sweep_idx]._spikes_df
485-
threshold_t = spike_df["threshold_t"]
486489

487490
# sweep has ap during stimulus
488-
onset = self.lookup_sweep_feature("stim_onset")[sweep_idx]
489-
end = self.lookup_sweep_feature("stim_end")[sweep_idx]
490-
stim_window = where_between(threshold_t.to_numpy(), onset, end)
491+
stim_window = where_spike_during_stimulus(self, recompute=recompute)
491492

492493
if np.any(stim_window):
493494
first_spike = spike_df[stim_window].iloc[0]

0 commit comments

Comments
 (0)