-
Notifications
You must be signed in to change notification settings - Fork 21
[WIP] TimeFrequency CSP Class #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,255 @@ | ||
| # -*- coding: utf-8 -*- | ||
| # Authors: Laura Gwilliams <laura.gwilliams@nyu.edu> | ||
| # Jean-Remi King <jeanremi.king@gmail.com> | ||
| # Alex Barachant <alexandre.barachant@gmail.com> | ||
| # Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr> | ||
| # | ||
| # License: BSD (3-clause) | ||
|
|
||
| import numpy as np | ||
| from mne.filter import filter_data | ||
|
|
||
|
|
||
| class TimeFrequencyCSP(object): | ||
| u"""Decode MEG data in time-frequency space using a rolling covariance | ||
| matrix. | ||
|
|
||
| This object can be used as a supervised decomposition to estimate | ||
| spatial filters for feature extraction in a 2 class decoding problem. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| tmin : float | ||
| Time in seconds. Lower bound of the time over which to estimate | ||
| signals. | ||
| tmax : float | ||
| Times in seconds. Upper bound of the time over which to estimate | ||
| signals. | ||
| freqs : array of floats | ||
| Frequency values over which to compute time-frequency decomposition. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you elaborate a bit on how exactly to pass the frequency ranges? Maybe give an example here? |
||
| estimator : object | ||
| Scikit-learn classifier object. | ||
| sfreq : float | ||
| Sample frequency of data. | ||
| n_cycles : float, defaults to 7. | ||
| Number of cycles that each frequency should complete in forming the | ||
| sliding temporal window. | ||
| The time-window length is T = n_cycles / freq. | ||
| scorer : object | None | str | ||
| scikit-learn Scorer instance or str type indicating the name of the | ||
| scorer such as ``accuracy``, ``roc_auc``. If None, set to ``accuracy``. | ||
| n_jobs : int | ||
| Number of jobs to run in parallel. Defaults to 1. | ||
|
|
||
| Attributes | ||
| ---------- | ||
| ``filters_`` : ndarray, shape (n_channels, n_channels) | ||
| If fit, the CSP components used to decompose the data, else None. | ||
| ``patterns_`` : ndarray, shape (n_channels, n_channels) | ||
| If fit, the CSP patterns used to restore M/EEG signals, else None. | ||
| ``mean_`` : ndarray, shape (n_components,) | ||
| If fit, the mean squared power for each component. | ||
| ``std_`` : ndarray, shape (n_components,) | ||
| If fit, the std squared power for each component. | ||
| """ | ||
|
|
||
| def __init__(self, tmin, tmax, freqs, estimator, sfreq, n_cycles=7., | ||
| scorer=None, n_jobs=1): | ||
| """Init of TimeFrequencyCSP.""" | ||
|
|
||
| # check frequency list | ||
| if len(freqs) < 2: | ||
| raise ValueError('freqs must contain more than one value.') | ||
| self.freqs = freqs | ||
|
|
||
| # check sfreq | ||
| if not isinstance(sfreq, (float, int)): | ||
| raise ValueError('sfreq must be a float or an int, got %s ' | ||
| 'instead.' % type(sfreq)) | ||
| self.sfreq = float(sfreq) | ||
|
|
||
| # check n_cycles | ||
| if not isinstance(n_cycles, (float, int)): | ||
| raise ValueError('sfreq must be a float or an int, got %s ' | ||
| 'instead.' % type(n_cycles)) | ||
| self.n_cycles = float(n_cycles) | ||
|
|
||
| # TODO:add test that the number of cycles is appropriate for epochs | ||
|
|
||
| self.estimator = estimator | ||
| self.scorer = scorer | ||
| self.n_jobs = n_jobs | ||
|
|
||
| # Assemble list of frequency range tuples | ||
| self._freq_ranges = zip(self.freqs[:-1], freqs[1:]) | ||
|
|
||
| # Infer window spacing from the max freq and n cycles to avoid gaps | ||
| self._window_spacing = (n_cycles / np.max(freqs) / 2.) | ||
| self._centered_w_times = np.arange(tmin + self._window_spacing, | ||
| tmax - self._window_spacing, | ||
| self._window_spacing)[1:] # noqa | ||
| self._n_windows = len(self._centered_w_times) | ||
|
|
||
| def _transform(self, epochs, y, w_tmin, w_tmax, method='', pos=[]): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do the |
||
| """ | ||
| Assumes data for one frequency band for one time window and fits CSP. | ||
| """ | ||
| from sklearn import clone | ||
|
|
||
| # Crop data into time-window of interest | ||
| Xt = epochs.copy().crop(w_tmin, w_tmax).get_data() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why |
||
|
|
||
| # call fit or predict depending on calling function | ||
| if method == 'fit': | ||
| self.estimator = clone(self.estimator) | ||
| self._estimators[pos] = (self.estimator.fit(Xt, y)) | ||
| return self | ||
|
|
||
| elif method == 'predict': | ||
| self.y_pred = self._estimators[pos].predict(Xt) | ||
|
|
||
| elif method == 'score': | ||
| self.y_pred = self._estimators[pos].predict(Xt) | ||
| return self._estimators[pos].score(Xt, y) | ||
|
|
||
| def fit(self, epochs, y=None): | ||
| """Train a classifier on each specified frequency and time slice. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| epochs : instance of Epochs | ||
| The epochs. | ||
| y : list or ndarray of int, shape (n_samples,) or None, optional | ||
| To-be-fitted model values. If None, y = epochs.events[:, 2]. | ||
|
|
||
| Returns | ||
| ------- | ||
| self : TimeFrequencyCSP | ||
| Returns fitted TimeFrequencyCSP object. | ||
| """ | ||
| n_freq = len(self._freq_ranges) | ||
| n_window = len(self._centered_w_times) | ||
| self._scores = np.zeros([n_freq, n_window]) | ||
| self._estimators = np.empty([n_freq, n_window], dtype=object) | ||
|
|
||
| if y is None: | ||
| y = epochs.events[:, 2] | ||
|
|
||
| # loop through each frequency range | ||
| for freq_ii, (fmin, fmax) in enumerate(self._freq_ranges): | ||
|
|
||
| # Infer window size based on the frequency being used | ||
| w_size = self.n_cycles / ((fmax + fmin) / 2.) # in seconds | ||
|
|
||
| # filter the data at the desired frequency | ||
| epochs_bp = epochs.copy() | ||
| epochs_bp._data = filter_data(epochs_bp.get_data(), self.sfreq, | ||
| fmin, fmax, verbose='CRITICAL') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A thought: maybe it is helpful to expose some parameters for the filtering? Perhaps not, as the number of parameters to this class will grow quite large. |
||
|
|
||
| # Roll covariance, csp and lda over time | ||
| for t, w_time in enumerate(self._centered_w_times): | ||
|
|
||
| # Center the min and max of the window | ||
| w_tmin = w_time - w_size / 2. | ||
| w_tmax = w_time + w_size / 2. | ||
|
|
||
| # extract data for this time window | ||
| self._transform(epochs_bp, y, w_tmin, w_tmax, method='fit', | ||
| pos=(freq_ii, t)) | ||
| pass | ||
|
|
||
| def predict(self, epochs): | ||
| """Test each classifier on each specified testing frequency and time | ||
| slice. | ||
|
|
||
| .. note:: | ||
| This function sets the ``y_pred_`` and ``test_times_`` attributes. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| epochs : instance of Epochs | ||
| The epochs. Can be similar to fitted epochs or not. See | ||
| predict_mode parameter. | ||
|
|
||
| Returns | ||
| ------- | ||
| y_pred : | ||
| The single-trial predictions at each time window. | ||
| """ | ||
| # loop through each frequency range | ||
| for freq_ii, (fmin, fmax) in enumerate(self._freq_ranges): | ||
|
|
||
| # Infer window size based on the frequency being used | ||
| w_size = self.n_cycles / ((fmax + fmin) / 2.) # in seconds | ||
|
|
||
| # filter the data at the desired frequency | ||
| epochs_bp = epochs.copy() | ||
| epochs_bp._data = filter_data(epochs_bp.get_data(), self.sfreq, | ||
| fmin, fmax, verbose='CRITICAL') | ||
|
|
||
| # Roll covariance, csp and lda over time | ||
| for t, w_time in enumerate(self._centered_w_times): | ||
|
|
||
| # Center the min and max of the window | ||
| w_tmin = w_time - w_size / 2. | ||
| w_tmax = w_time + w_size / 2. | ||
|
|
||
| # extract data for this time window | ||
| self._transform(epochs_bp, None, w_tmin, w_tmax, | ||
| method='predict', pos=(freq_ii, t)) | ||
| pass | ||
|
|
||
| def score(self, epochs, y): | ||
| """Score Epochs. | ||
|
|
||
| Estimate scores across trials by comparing the prediction estimated for | ||
| each trial to its true value. | ||
|
|
||
| Calls ``predict()``. | ||
|
|
||
| .. note:: | ||
| The function updates the ``scorer_``, ``scores_``, and | ||
| ``y_true_`` attributes. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| epochs : instance of Epochs | None, optional | ||
| The epochs. Can be similar to fitted epochs or not. | ||
| If None, it needs to rely on the predictions ``y_pred_`` | ||
| generated with ``predict()``. | ||
| y : list | ndarray, shape (n_epochs,) | None, optional | ||
| True values to be compared with the predictions ``y_pred_`` | ||
| generated with ``predict()`` via ``scorer_``. | ||
|
|
||
| Returns | ||
| ------- | ||
| scores : list of float, shape (n_times,) | ||
| The scores estimated by ``scorer_`` at each time sample (e.g. mean | ||
| accuracy of ``predict(X)``). | ||
| """ | ||
| if y is None: | ||
| y = epochs.events[:, 2] | ||
|
|
||
| # loop through each frequency range | ||
| for freq_ii, (fmin, fmax) in enumerate(self._freq_ranges): | ||
| # Infer window size based on the frequency being used | ||
| w_size = self.n_cycles / ((fmax + fmin) / 2.) # in seconds | ||
|
|
||
| # filter the data at the desired frequency | ||
| epochs_bp = epochs.copy() | ||
| epochs_bp._data = filter_data(epochs_bp.get_data(), self.sfreq, | ||
| fmin, fmax, verbose='CRITICAL') | ||
|
|
||
| # Roll covariance, csp and lda over time | ||
| for t, w_time in enumerate(self._centered_w_times): | ||
|
|
||
| # Center the min and max of the window | ||
| w_tmin = w_time - w_size / 2. | ||
| w_tmax = w_time + w_size / 2. | ||
|
|
||
| # extract data for this time window | ||
| s = self._transform(epochs_bp, y, w_tmin, w_tmax, | ||
| method='score', pos=(freq_ii, t)) | ||
| self._scores[freq_ii, t] = s | ||
|
|
||
| return self._scores | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this class follows the scikit-learn API, can we make this a subclass of
sklearn.base.BaseEstimatorand add thesklearn.base.TransformerMixin?