Skip to content

Commit 20f1159

Browse files
committed
fix: bugfixes
1 parent ae021fb commit 20f1159

File tree

5 files changed

+116
-82
lines changed

5 files changed

+116
-82
lines changed

ephyspy/allen_sdk/ephys_features.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,12 @@ def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10.0, dvdt=N
203203
if dvdt is None:
204204
dvdt = calculate_dvdt(v, t, filter)
205205

206-
upstroke_indexes = [
207-
np.argmax(dvdt[spike:peak]) + spike
208-
for spike, peak in zip(spike_indexes, peak_indexes)
209-
]
206+
upstroke_indexes = []
207+
for spike, peak in zip(spike_indexes, peak_indexes):
208+
if len(dvdt[spike:peak]) > 0:
209+
upstroke_indexes.append(np.argmax(dvdt[spike:peak]) + spike)
210+
else:
211+
upstroke_indexes.append(np.nan)
210212

211213
return np.array(upstroke_indexes)
212214

@@ -985,7 +987,8 @@ def detect_pauses(isis, isi_types, cost_weight=1.0):
985987
break
986988
cv = non_pause_isis.std() / non_pause_isis.mean()
987989
benefit = all_cv - cv
988-
cost = np.sum(non_pause_isis.std() / np.abs(non_pause_isis.mean() - pause_isis))
990+
div = non_pause_isis.mean() - pause_isis
991+
cost = np.sum(non_pause_isis.std() / np.abs(div if div != 0 else float("nan")))
989992
cost *= cost_weight
990993
net = benefit - cost
991994
if net > 0 and net < best_net:

ephyspy/features/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ def lookup_spike_feature(
291291
):
292292
sweep.process_spikes()
293293

294+
# Allen dataframe does not always have all columns even though the features
295+
# in theory could be computed.
296+
if is_allen_ft and feature_name not in sweep._spikes_df.columns:
297+
return np.array([])
294298
return sweep.spike_feature(feature_name, include_clipped=True)
295299

296300
def get_value(

ephyspy/features/spike_features.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,11 @@ def __init__(self, data=None, **kwargs):
552552

553553
def _compute(self, recompute=False, store_diagnostics=True):
554554
width = self.lookup_spike_feature("width", recompute=recompute)
555-
trough_idxs = self.lookup_spike_feature("trough_index").astype(int)
556-
spike_idxs = self.lookup_spike_feature("threshold_index").astype(int)
557-
peak_idxs = self.lookup_spike_feature("peak_index").astype(int)
555+
556+
# values are floats since they can be nan
557+
trough_idxs = self.lookup_spike_feature("trough_index")
558+
spike_idxs = self.lookup_spike_feature("threshold_index")
559+
peak_idxs = self.lookup_spike_feature("peak_index")
558560

559561
if store_diagnostics:
560562
self._update_diagnostics(
@@ -574,6 +576,9 @@ def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes
574576
trough_idxs, spike_idxs, peak_idxs = unpack(
575577
self.diagnostics, ["trough_idx", "spike_idx", "peak_idx"]
576578
)
579+
trough_idxs = trough_idxs.astype(int)
580+
spike_idxs = spike_idxs.astype(int)
581+
peak_idxs = peak_idxs.astype(int)
577582

578583
t = self.data.t
579584
v = self.data.v

ephyspy/features/sweep_features.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,11 @@ def _compute(self, recompute=False, store_diagnostics=True):
587587
isi_ff = float("nan")
588588
if has_spikes(self.data):
589589
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
590+
during_stim = where_spike_during_stimulus(self, recompute=recompute)[1:]
591+
if np.any(during_stim):
592+
isi = isi[during_stim]
593+
isi = isi[1:] if isi[0] == 0 else isi
594+
590595
if len(isi) > 1:
591596
isi_ff = np.nanvar(isi) / np.nanmean(isi)
592597

@@ -619,6 +624,11 @@ def _compute(self, recompute=False, store_diagnostics=True):
619624
isi_cv = float("nan")
620625
if has_spikes(self.data):
621626
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
627+
during_stim = where_spike_during_stimulus(self, recompute=recompute)[1:]
628+
if np.any(during_stim):
629+
isi = isi[during_stim]
630+
isi = isi[1:] if isi[0] == 0 else isi
631+
622632
if len(isi) > 1:
623633
isi_cv = np.nanstd(isi) / np.nanmean(isi)
624634

@@ -1513,15 +1523,16 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
15131523
def _compute(self, recompute=False, store_diagnostics=True):
15141524
isi_adapt = float("nan")
15151525
if has_spikes(self.data):
1516-
isi = self.lookup_spike_feature("isi", recompute=recompute)
1517-
during_stim = where_spike_during_stimulus(self, recompute=recompute)
1518-
isi = isi[during_stim]
1519-
isi = isi[1:] if isi[0] == 0 else isi
1520-
if len(isi) > 1:
1521-
isi_adapt = isi[1] / isi[0]
1526+
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
1527+
during_stim = where_spike_during_stimulus(self, recompute=recompute)[1:]
1528+
if np.any(during_stim):
1529+
isi = isi[during_stim]
1530+
isi = isi[1:] if isi[0] == 0 else isi
1531+
if len(isi) > 1:
1532+
isi_adapt = isi[1] / isi[0]
15221533

1523-
if store_diagnostics:
1524-
self._update_diagnostics({"isi": isi})
1534+
if store_diagnostics:
1535+
self._update_diagnostics({"isi": isi})
15251536
return isi_adapt
15261537

15271538
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
@@ -1544,16 +1555,17 @@ def __init__(self, data=None, compute_at_init=True, **kwargs):
15441555
def _compute(self, recompute=False, store_diagnostics=True):
15451556
isi_adapt_avg = float("nan")
15461557
if has_spikes(self.data):
1547-
isi = self.lookup_spike_feature("isi", recompute=recompute)
1548-
during_stim = where_spike_during_stimulus(self, recompute=recompute)
1549-
isi = isi[during_stim]
1550-
isi = isi[1:] if isi[0] == 0 else isi
1551-
if len(isi) > 2:
1552-
isi_changes = isi[1:] / isi[:-1]
1553-
isi_adapt_avg = isi_changes.mean()
1558+
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
1559+
during_stim = where_spike_during_stimulus(self, recompute=recompute)[1:]
1560+
if np.any(during_stim):
1561+
isi = isi[during_stim]
1562+
isi = isi[1:] if isi[0] == 0 else isi
1563+
if len(isi) > 2:
1564+
isi_changes = isi[1:] / isi[:-1]
1565+
isi_adapt_avg = isi_changes.mean()
15541566

1555-
if store_diagnostics:
1556-
self._update_diagnostics({"isi": isi})
1567+
if store_diagnostics:
1568+
self._update_diagnostics({"isi": isi})
15571569
return isi_adapt_avg
15581570

15591571
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:

ephyspy/features/sweepset_features.py

Lines changed: 67 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def _select(self, fts):
9999
peaks_to_low = np.all(self.lookup_sweep_feature("ap_peak") < -30)
100100

101101
if not peaks_to_low:
102-
idx = np.where(is_depol & has_spikes)[0][0]
102+
idx = np.where(is_depol & has_spikes)[0]
103+
idx = idx[0] if len(idx) > 0 else np.array([], dtype=int)
103104
else:
104105
idx = np.array([], dtype=int)
105106
elif stimulus_type(self.data) == "ramp":
@@ -199,8 +200,9 @@ def _select(self, fts):
199200
num_spikes = self.lookup_sweep_feature("num_ap")
200201
wildness = self.lookup_sweep_feature("wildness")
201202
is_non_wild = np.isnan(wildness)
202-
idx = pd.Series(num_spikes)[is_non_wild].idxmax()
203-
idx = np.array([], dtype=int) if np.isnan(idx) else idx
203+
num_non_wild_spikes = pd.Series(num_spikes)[is_non_wild]
204+
no_spikes = np.all(np.isnan(num_non_wild_spikes))
205+
idx = np.array([], dtype=int) if no_spikes else num_non_wild_spikes.idxmax()
204206

205207
self._update_diagnostics(
206208
{
@@ -459,31 +461,33 @@ def _compute(self, recompute=False, store_diagnostics=False):
459461
i = stim_amp[is_depol]
460462

461463
has_spikes = ~np.isnan(f)
462-
# sometimes all depolarization traces spike
463-
i_sub = (
464-
0 if all(has_spikes) else i[~has_spikes][0]
465-
) # last stim < spike threshold
466-
i_sup = i[has_spikes][0] # first stim > spike threshold
464+
if np.any(has_spikes):
465+
# sometimes all depolarization traces spike
466+
i_sub = (
467+
0 if all(has_spikes) else i[~has_spikes][0]
468+
) # last stim < spike threshold
469+
i_sup = i[has_spikes][0] # first stim > spike threshold
470+
471+
if not np.isnan(dfdi):
472+
rheobase = float(ransac.predict(np.array([[0]]))) / dfdi
473+
474+
if rheobase < i_sub or rheobase > i_sup:
475+
rheobase = i_sup
476+
else:
477+
rheobase = i_sup
478+
rheobase -= dc_offset
467479

468-
if not np.isnan(dfdi):
469-
rheobase = float(ransac.predict(np.array([[0]]))) / dfdi
480+
if store_diagnostics:
481+
self._update_diagnostics(
482+
{
483+
"i_sub": i_sub,
484+
"i_sup": i_sup,
485+
"f_sup": f[has_spikes][0],
486+
"dfdi": dfdi,
487+
"dc_offset": dc_offset,
488+
}
489+
)
470490

471-
if rheobase < i_sub or rheobase > i_sup:
472-
rheobase = i_sup
473-
else:
474-
rheobase = i_sup
475-
rheobase -= dc_offset
476-
477-
if store_diagnostics:
478-
self._update_diagnostics(
479-
{
480-
"i_sub": i_sub,
481-
"i_sup": i_sup,
482-
"f_sup": f[has_spikes][0],
483-
"dfdi": dfdi,
484-
"dc_offset": dc_offset,
485-
}
486-
)
487491
if stimulus_type(self.data) == "ramp":
488492
has_ap = self.lookup_sweep_feature("num_ap", recompute=recompute) > 0
489493
if np.any(has_ap):
@@ -685,19 +689,22 @@ def _compute(self, recompute=False, store_diagnostics=False):
685689
slow_hyperpolarization = float("nan")
686690
if stimulus_type(self.data) == "long_square":
687691
has_aps = self.lookup_sweep_feature("num_ap", recompute=recompute) > 0
688-
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
689-
v_baseline = v_baseline[has_aps]
690-
691-
slow_hyperpolarization = v_baseline.max() - v_baseline.min()
692-
693-
if store_diagnostics:
694-
self._update_diagnostics(
695-
{
696-
"v_baseline": v_baseline,
697-
"v_baseline_max": v_baseline.max(),
698-
"v_baseline_min": v_baseline.min(),
699-
}
692+
if np.any(has_aps):
693+
v_baseline = self.lookup_sweep_feature(
694+
"v_baseline", recompute=recompute
700695
)
696+
v_baseline = v_baseline[has_aps]
697+
698+
slow_hyperpolarization = v_baseline.max() - v_baseline.min()
699+
700+
if store_diagnostics:
701+
self._update_diagnostics(
702+
{
703+
"v_baseline": v_baseline,
704+
"v_baseline_max": v_baseline.max(),
705+
"v_baseline_min": v_baseline.min(),
706+
}
707+
)
701708
return slow_hyperpolarization
702709

703710
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
@@ -730,26 +737,29 @@ def _compute(self, recompute=False, store_diagnostics=False):
730737
slow_hyperpolarization_slope = float("nan")
731738
if stimulus_type(self.data) == "long_square":
732739
has_aps = self.lookup_sweep_feature("num_ap", recompute=recompute) > 0
733-
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
734-
v_baseline = v_baseline[has_aps]
735-
736-
v_baseline = v_baseline.reshape(-1, 1)
737-
sweep_idx = np.arange(len(v_baseline)).reshape(-1, 1)
738-
739-
if len(v_baseline) >= 3:
740-
ransac.fit(sweep_idx, v_baseline)
741-
slope = ransac.coef_[0, 0] * 1000
742-
intercept = ransac.intercept_[0]
743-
slow_hyperpolarization_slope = slope
744-
745-
if store_diagnostics:
746-
self._update_diagnostics(
747-
{
748-
"v_baseline": v_baseline,
749-
"sweep_idx": sweep_idx,
750-
"v_intercept": intercept,
751-
}
740+
if np.any(has_aps):
741+
v_baseline = self.lookup_sweep_feature(
742+
"v_baseline", recompute=recompute
752743
)
744+
v_baseline = v_baseline[has_aps]
745+
746+
v_baseline = v_baseline.reshape(-1, 1)
747+
sweep_idx = np.arange(len(v_baseline)).reshape(-1, 1)
748+
749+
if len(v_baseline) >= 3:
750+
ransac.fit(sweep_idx, v_baseline)
751+
slope = ransac.coef_[0, 0] * 1000
752+
intercept = ransac.intercept_[0]
753+
slow_hyperpolarization_slope = slope
754+
755+
if store_diagnostics:
756+
self._update_diagnostics(
757+
{
758+
"v_baseline": v_baseline,
759+
"sweep_idx": sweep_idx,
760+
"v_intercept": intercept,
761+
}
762+
)
753763
return slow_hyperpolarization_slope
754764

755765
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:

0 commit comments

Comments
 (0)