2424from matplotlib .pyplot import Axes
2525from numpy import ndarray
2626
27- from ephyspy .features .utils import FeatureError , fetch_available_fts
27+ from ephyspy .features .utils import (
28+ FeatureError ,
29+ fetch_available_fts ,
30+ get_allensdk_spike_features ,
31+ )
2832from ephyspy .sweeps import EphysSweep , EphysSweepSet
2933from ephyspy .utils import (
34+ is_spike_feature ,
3035 is_sweep_feature ,
3136 is_sweepset_feature ,
3237 parse_deps ,
@@ -239,6 +244,55 @@ def _update_diagnostics(self, dct: Dict[str, Any]):
239244 self ._diagnostics = {}
240245 self ._diagnostics .update (dct )
241246
247+ def lookup_spike_feature (
248+ self , feature_name : str , recompute : bool = False
249+ ) -> ndarray :
250+ """Look up a spike level feature and return its value.
251+
252+ This method will first check if the feature is already computed,
253+ and if not, compute all spike level features using `process_spikes` from
254+ the underlying `EphysSweep` object, and then
255+ instantiate and compute the feature. Lookup is recursive and considers all
256+ registered and implemented spike features.
257+
258+ Args:
259+ feature_name: Name of the feature to look up.
260+ recompute: If True, recompute the feature even if it is already
261+ computed.
262+
263+ Returns:
264+ The value of the feature for each detected spike.
265+
266+ Raises:
267+ FeatureError: If the feature is not found via `fetch_available_fts`.
268+ """
269+ sweep = self .data
270+ is_allen_ft = feature_name in get_allensdk_spike_features ()
271+ ft_already_added = feature_name in sweep .added_spike_features
272+ if not (is_allen_ft or ft_already_added ):
273+ available_fts = fetch_available_fts ()
274+ available_fts = [ft for ft in available_fts if is_spike_feature (ft )]
275+ available_fts = {
276+ ft .__name__ .lower ().replace ("spike_" , "" ): ft for ft in available_fts
277+ }
278+ if feature_name in available_fts :
279+ feature = available_fts [feature_name ](sweep , compute_at_init = False )
280+ sweep .add_spike_feature (feature .name , feature )
281+ else :
282+ raise FeatureError (
283+ f"{ feature_name } was not found. Make sure it is implemented or registered."
284+ )
285+
286+ if not hasattr (sweep , "_spikes_df" ) or recompute :
287+ sweep .process_spikes ()
288+ elif (
289+ feature_name in sweep .added_spike_features # prevents RecursionError
290+ and feature_name not in sweep ._spikes_df .columns
291+ ):
292+ sweep .process_spikes ()
293+
294+ return sweep .spike_feature (feature_name , include_clipped = True )
295+
242296 def get_value (
243297 self , recompute : bool = False , store_diagnostics : bool = True
244298 ) -> float :
@@ -430,9 +484,6 @@ class SpikeFeature(BaseFeature):
430484 compute spike features with `EphysSweep.process_spikes`, while being able to
431485 provide additional functionality to the spike feature class.
432486
433- Currently, no diagnostics or recursive feature lookup is supported for spike
434- features! For now this class mainly just acts as a feature function.
435-
436487 The description of the feature should contain a short description of the
437488 feature, and a list of dependencies. The dependencies should be listed
438489 as a comma separated list of feature names. It is parsed and can be displayed
@@ -449,19 +500,21 @@ class SpikeFeature(BaseFeature):
449500
450501 <Some more text>'''
451502
452- All computed features are added to the underlying `EphysSweep`
453- object, and can be accessed via `lookup_spike_feature`. The methods will
454- first check if the feature is already computed, and if not, instantiate and
455- compute it. Any dependencies already computed will be reused, unless
456- `recompute=True` is passed.
503+ All computed features are added to `EphysSweep.added_spike_features`, and
504+ can be accessed via `lookup_spike_feature`. The methods will first check if
505+ the feature is already computed, and if not, instantiate and compute it.
506+ This works recursively, so that features can depend on other features as
507+ long as they are looked up with `lookup_spike_feature`. Hence any feature
508+ can be computed at any point, without having to compute any dependencies first.
509+ Any dependencies already computed will be reused, unless `recompute=True` is
510+ passed.
457511
458512 `SpikeFeature`s can also implement a _plot method, the feature. If the
459513 feature cannot be displayed in a V(t) or I(t) plot, instead the `plot` method
460514 should be overwritten directly. This is because `plot` wraps `_plot` adds
461515 additional functionality ot it.
462516 """
463517
464- # TODO: Add support for recursive feature lookup
465518 def __init__ (
466519 self ,
467520 data : Optional [EphysSweep ] = None ,
@@ -504,33 +557,6 @@ def _data_init(self, data: EphysSweep):
504557 self .type = type (data ).__name__
505558 self .ensure_correct_hyperparams ()
506559
507- def lookup_spike_feature (
508- self , feature_name : str , recompute : bool = False
509- ) -> ndarray :
510- """Look up a spike level feature and return its value.
511-
512- This method will first check if the feature is already computed,
513- and if not, compute all spike level features using `process_spikes` from
514- the underlying `EphysSweep` object, and then
515- instantiate and compute the feature.
516-
517- Args:
518- feature_name: Name of the feature to look up.
519- recompute: If True, recompute the feature even if it is already
520- computed.
521-
522- Returns:
523- The value of the feature for each detected spike.
524- """
525- if not hasattr (self .data , "_spikes_df" ) or recompute :
526- self .data .process_spikes ()
527- elif (
528- feature_name in self .data .added_spike_features
529- and feature_name not in self .data ._spikes_df .columns
530- ):
531- self .data .process_spikes ()
532- return self .data .spike_feature (feature_name , include_clipped = True )
533-
534560 def __str__ (self ):
535561 name = f"{ self .name } \n "
536562 vals = "\n " .join (
@@ -707,33 +733,6 @@ def lookup_sweep_feature(
707733 return ft .get_value (recompute = recompute )
708734 return ft
709735
710- def lookup_spike_feature (
711- self , feature_name : str , recompute : bool = False
712- ) -> ndarray :
713- """Look up a spike level feature and return its value.
714-
715- This method will first check if the feature is already computed,
716- and if not, compute all spike level features using `process_spikes` from
717- the underlying `EphysSweep` object, and then
718- instantiate and compute the feature.
719-
720- Args:
721- feature_name: Name of the feature to look up.
722- recompute: If True, recompute the feature even if it is already
723- computed.
724-
725- Returns:
726- The value of the feature for each detected spike.
727- """
728- if not hasattr (self .data , "_spikes_df" ) or recompute :
729- self .data .process_spikes ()
730- elif (
731- feature_name in self .data .added_spike_features
732- and feature_name not in self .data ._spikes_df .columns
733- ):
734- self .data .process_spikes ()
735- return self .data .spike_feature (feature_name , include_clipped = True )
736-
737736 def __repr__ (self ):
738737 return f"{ self .name } for { self .data } "
739738
0 commit comments