|
22 | 22 | import numpy as np |
23 | 23 | from matplotlib.pyplot import Axes |
24 | 24 | from numpy import ndarray |
| 25 | +from pandas import DataFrame |
25 | 26 |
|
26 | 27 | import ephyspy.allen_sdk.ephys_extractor as efex |
27 | 28 | from ephyspy.allen_sdk.ephys_extractor import ( |
@@ -151,6 +152,21 @@ def get_features(self, recompute: bool = False) -> Dict[str, float]: |
151 | 152 | for k, ft in self.features.items() |
152 | 153 | } |
153 | 154 |
|
| 155 | + def get_spike_features(self, recompute: bool = False) -> DataFrame: |
| 156 | + """Compute all spike features that have been added to the `EphysSweep` instance. |
| 157 | +
|
| 158 | + Includes all features that can be found in `self.added_spike_features`. |
| 159 | +
|
| 160 | + Args: |
| 161 | + recompute (bool, optional): Whether to force recomputation of the |
| 162 | + features. Defaults to False. |
| 163 | +
|
| 164 | + Returns: |
| 165 | + DataFrame: DataFrame of features and values.""" |
| 166 | + if not hasattr(self, "_spikes_df") or recompute: |
| 167 | + self.process_spikes() |
| 168 | + return self._spikes_df |
| 169 | + |
154 | 170 | def clear_features(self): |
155 | 171 | """Clear all features.""" |
156 | 172 | self.spikes_df = None |
@@ -459,6 +475,25 @@ def get_sweep_features(self, recompute: bool = False) -> Dict[str, List[float]]: |
459 | 475 | LD = [sw.get_features(recompute=recompute) for sw in self.sweeps()] |
460 | 476 | return {k: [dic[k] for dic in LD] for k in LD[0]} |
461 | 477 |
|
| 478 | + def get_spike_features(self, recompute: bool = False) -> List[DataFrame]: |
| 479 | + """Collect spike features on a sweep level. |
| 480 | +
|
| 481 | + This computes / looks up all spike features that have been computed at |
| 482 | + the sweep level and returns them as a list of dataframes. Each dataframe |
| 483 | + contains the values for the respective feature for each spike, i.e. |
| 484 | + `get_spike_features()[sweep_idx][feature_name]` returns the values of |
| 485 | + `feature_name` for the `sweep_idx`-th sweep. |
| 486 | +
|
| 487 | + Args: |
| 488 | + recompute (bool, optional): Whether to force recomputation of the |
| 489 | + features. Defaults to False. |
| 490 | +
|
| 491 | + Returns: |
| 492 | + Dict[str, List[float]]: Dictionary of features and values. |
| 493 | + """ |
| 494 | + dfs = [sw.get_spike_features(recompute=recompute) for sw in self.sweeps()] |
| 495 | + return dfs |
| 496 | + |
462 | 497 | def plot( |
463 | 498 | self, ax: Optional[Axes] = None, show_stimulus: bool = False, **kwargs |
464 | 499 | ) -> Axes: |
|
0 commit comments