Skip to content

Commit 7fdc561

Browse files
committed
enh: Make AP features recursive
1 parent 7509350 commit 7fdc561

File tree

4 files changed

+104
-68
lines changed

4 files changed

+104
-68
lines changed

ephyspy/features/base.py

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@
2424
from matplotlib.pyplot import Axes
2525
from numpy import ndarray
2626

27-
from ephyspy.features.utils import FeatureError, fetch_available_fts
27+
from ephyspy.features.utils import (
28+
FeatureError,
29+
fetch_available_fts,
30+
get_allensdk_spike_features,
31+
)
2832
from ephyspy.sweeps import EphysSweep, EphysSweepSet
2933
from ephyspy.utils import (
34+
is_spike_feature,
3035
is_sweep_feature,
3136
is_sweepset_feature,
3237
parse_deps,
@@ -239,6 +244,55 @@ def _update_diagnostics(self, dct: Dict[str, Any]):
239244
self._diagnostics = {}
240245
self._diagnostics.update(dct)
241246

247+
def lookup_spike_feature(
248+
self, feature_name: str, recompute: bool = False
249+
) -> ndarray:
250+
"""Look up a spike level feature and return its value.
251+
252+
This method will first check if the feature is already computed,
253+
and if not, compute all spike level features using `process_spikes` from
254+
the underlying `EphysSweep` object, and then
255+
instantiate and compute the feature. Lookup is recursive and considers all
256+
registered and implemented spike features.
257+
258+
Args:
259+
feature_name: Name of the feature to look up.
260+
recompute: If True, recompute the feature even if it is already
261+
computed.
262+
263+
Returns:
264+
The value of the feature for each detected spike.
265+
266+
Raises:
267+
FeatureError: If the feature is not found via `fetch_available_fts`.
268+
"""
269+
sweep = self.data
270+
is_allen_ft = feature_name in get_allensdk_spike_features()
271+
ft_already_added = feature_name in sweep.added_spike_features
272+
if not (is_allen_ft or ft_already_added):
273+
available_fts = fetch_available_fts()
274+
available_fts = [ft for ft in available_fts if is_spike_feature(ft)]
275+
available_fts = {
276+
ft.__name__.lower().replace("spike_", ""): ft for ft in available_fts
277+
}
278+
if feature_name in available_fts:
279+
feature = available_fts[feature_name](sweep, compute_at_init=False)
280+
sweep.add_spike_feature(feature.name, feature)
281+
else:
282+
raise FeatureError(
283+
f"{feature_name} was not found. Make sure it is implemented or registered."
284+
)
285+
286+
if not hasattr(sweep, "_spikes_df") or recompute:
287+
sweep.process_spikes()
288+
elif (
289+
feature_name in sweep.added_spike_features # prevents RecursionError
290+
and feature_name not in sweep._spikes_df.columns
291+
):
292+
sweep.process_spikes()
293+
294+
return sweep.spike_feature(feature_name, include_clipped=True)
295+
242296
def get_value(
243297
self, recompute: bool = False, store_diagnostics: bool = True
244298
) -> float:
@@ -430,9 +484,6 @@ class SpikeFeature(BaseFeature):
430484
compute spike features with `EphysSweep.process_spikes`, while being able to
431485
provide additional functionality to the spike feature class.
432486
433-
Currently, no diagnostics or recursive feature lookup is supported for spike
434-
features! For now this class mainly just acts as a feature function.
435-
436487
The description of the feature should contain a short description of the
437488
feature, and a list of dependencies. The dependencies should be listed
438489
as a comma separated list of feature names. It is parsed and can be displayed
@@ -449,19 +500,21 @@ class SpikeFeature(BaseFeature):
449500
450501
<Some more text>'''
451502
452-
All computed features are added to the underlying `EphysSweep`
453-
object, and can be accessed via `lookup_spike_feature`. The methods will
454-
first check if the feature is already computed, and if not, instantiate and
455-
compute it. Any dependencies already computed will be reused, unless
456-
`recompute=True` is passed.
503+
All computed features are added to `EphysSweep.added_spike_features`, and
504+
can be accessed via `lookup_spike_feature`. The methods will first check if
505+
the feature is already computed, and if not, instantiate and compute it.
506+
This works recursively, so that features can depend on other features as
507+
long as they are looked up with `lookup_spike_feature`. Hence any feature
508+
can be computed at any point, without having to compute any dependencies first.
509+
Any dependencies already computed will be reused, unless `recompute=True` is
510+
passed.
457511
458512
`SpikeFeature`s can also implement a _plot method, the feature. If the
459513
feature cannot be displayed in a V(t) or I(t) plot, instead the `plot` method
460514
should be overwritten directly. This is because `plot` wraps `_plot` adds
461515
additional functionality ot it.
462516
"""
463517

464-
# TODO: Add support for recursive feature lookup
465518
def __init__(
466519
self,
467520
data: Optional[EphysSweep] = None,
@@ -504,33 +557,6 @@ def _data_init(self, data: EphysSweep):
504557
self.type = type(data).__name__
505558
self.ensure_correct_hyperparams()
506559

507-
def lookup_spike_feature(
508-
self, feature_name: str, recompute: bool = False
509-
) -> ndarray:
510-
"""Look up a spike level feature and return its value.
511-
512-
This method will first check if the feature is already computed,
513-
and if not, compute all spike level features using `process_spikes` from
514-
the underlying `EphysSweep` object, and then
515-
instantiate and compute the feature.
516-
517-
Args:
518-
feature_name: Name of the feature to look up.
519-
recompute: If True, recompute the feature even if it is already
520-
computed.
521-
522-
Returns:
523-
The value of the feature for each detected spike.
524-
"""
525-
if not hasattr(self.data, "_spikes_df") or recompute:
526-
self.data.process_spikes()
527-
elif (
528-
feature_name in self.data.added_spike_features
529-
and feature_name not in self.data._spikes_df.columns
530-
):
531-
self.data.process_spikes()
532-
return self.data.spike_feature(feature_name, include_clipped=True)
533-
534560
def __str__(self):
535561
name = f"{self.name}\n"
536562
vals = "\n".join(
@@ -707,33 +733,6 @@ def lookup_sweep_feature(
707733
return ft.get_value(recompute=recompute)
708734
return ft
709735

710-
def lookup_spike_feature(
711-
self, feature_name: str, recompute: bool = False
712-
) -> ndarray:
713-
"""Look up a spike level feature and return its value.
714-
715-
This method will first check if the feature is already computed,
716-
and if not, compute all spike level features using `process_spikes` from
717-
the underlying `EphysSweep` object, and then
718-
instantiate and compute the feature.
719-
720-
Args:
721-
feature_name: Name of the feature to look up.
722-
recompute: If True, recompute the feature even if it is already
723-
computed.
724-
725-
Returns:
726-
The value of the feature for each detected spike.
727-
"""
728-
if not hasattr(self.data, "_spikes_df") or recompute:
729-
self.data.process_spikes()
730-
elif (
731-
feature_name in self.data.added_spike_features
732-
and feature_name not in self.data._spikes_df.columns
733-
):
734-
self.data.process_spikes()
735-
return self.data.spike_feature(feature_name, include_clipped=True)
736-
737736
def __repr__(self):
738737
return f"{self.name} for {self.data}"
739738

ephyspy/features/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,46 @@ class FeatureError(ValueError):
105105
pass
106106

107107

108+
def get_allensdk_spike_features():
109+
return [
110+
"threshold_index",
111+
"threshold_t",
112+
"threshold_v",
113+
"threshold_i",
114+
"peak_index",
115+
"peak_t",
116+
"peak_v",
117+
"peak_i",
118+
"trough_index",
119+
"trough_t",
120+
"trough_v",
121+
"trough_i",
122+
"upstroke_index",
123+
"upstroke",
124+
"upstroke_t",
125+
"upstroke_v",
126+
"downstroke_index",
127+
"downstroke",
128+
"downstroke_t",
129+
"downstroke_v",
130+
"isi_type",
131+
"fast_trough_index",
132+
"fast_trough_t",
133+
"fast_trough_v",
134+
"fast_trough_i",
135+
"adp_index",
136+
"adp_t",
137+
"adp_v",
138+
"adp_i",
139+
"slow_trough_index",
140+
"slow_trough_t",
141+
"slow_trough_v",
142+
"slow_trough_i",
143+
"width",
144+
"upstroke_downstroke_ratio",
145+
]
146+
147+
108148
class during_stimulus_only:
109149
def __init__(self, sweep, T_stim=None):
110150
self.sweep = sweep

tests/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
filter=5,
2323
metadata={"dc_offset": -14.52083},
2424
)
25-
test_sweepset.add_features(available_spike_features())
25+
# test_sweepset.add_features(available_spike_features())
2626

2727
# create test sweeps
2828
depol_test_sweep = EphysSweep(t_set[11], u_set[11], i_set[11], start, end, filter=1)

tests/test_features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ def test_spike_feature(ft_func, sweep, is_depol):
7070

7171
# test value, diagnostics etc.
7272

73-
depol_test_sweep.add_features(available_spike_features())
74-
hyperpol_test_sweep.add_features(available_spike_features())
75-
7673

7774
@pytest.mark.parametrize(
7875
"Ft", available_sweep_features().values(), ids=available_sweep_features().keys()

0 commit comments

Comments
 (0)