diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 1131f1597c5..f08e6f35e9a 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -23,6 +23,7 @@ __all__ = [ "cross_val_multiscore", "get_coef", "get_spatial_filter_from_estimator", + "read_ssd", ] from .base import ( BaseEstimator, @@ -36,7 +37,7 @@ from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator -from .ssd import SSD +from .ssd import SSD, read_ssd from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency from .transformer import ( diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 3a51a04bed7..2e8c775be2f 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -135,6 +135,31 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._is_base_ged = False + def __getstate__(self): + """Get state for serialization. + + This explicitly drops callables and other runtime-only attributes. + Subclasses can extend this to add estimator-specific state. + """ + state = self.__dict__.copy() + + # Callables are not serializable and must be reconstructed + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + + return state + + def __setstate__(self, state): + """Restore state from serialization. + + Subclasses are responsible for reconstructing dropped callables. + """ + self.__dict__.update(state) + + # Ensure attributes exist even before reconstruction + self.cov_callable = None + self.mod_ged_callable = None + def fit(self, X, y=None): """...""" # Let the inheriting transformers check data by themselves diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 41d67ece8c6..0606aa468e8 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -7,6 +7,8 @@ import numpy as np +from mne import __version__ as mne_version + from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..filter import filter_data @@ -15,6 +17,7 @@ fill_doc, logger, ) +from ..utils.check import _check_fname, _import_h5io_funcs, check_fname from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer @@ -146,6 +149,14 @@ def __init__( restr_type=restr_type, ) + def __setstate__(self, state): + """Restore state from serialization.""" + # Since read_ssd creates a new instance via __init__ first, + # callables are already set correctly. We just restore fitted attributes. + # Don't call super().__setstate__() as it would set callables to None. + self.__dict__.update(state) + return self + def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing self.sfreq_ = self.info @@ -237,6 +248,33 @@ def fit(self, X, y=None): logger.info("Done.") return self + @fill_doc + def save(self, fname, *, overwrite=False, verbose=None): + """Save the SSD decomposition to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + Path of file to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_ssd + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "SSD", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + state = self.__getstate__() + state.update( + class_name="SSD", + mne_version=mne_version, + ) + write_hdf5( + fname, state, overwrite=overwrite, title="mnepython", slash="replace" + ) + def transform(self, X): """Estimate epochs sources given the SSD filters. @@ -350,3 +388,26 @@ def apply(self, X): pick_patterns = self.patterns_[: self.n_components].T X = pick_patterns @ X_ssd return X + + +def read_ssd(fname): + """Read an SSD object from disk. + + Parameters + ---------- + fname : path-like + Path to an SSD file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. + + Returns + ------- + ssd : SSD + The loaded SSD object. + """ + read_hdf5, _ = _import_h5io_funcs() + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + return SSD(info=None, filt_params_signal=None, filt_params_noise=None).__setstate__( + state + ) diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 236e65b82fd..ea577544fa2 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -16,7 +16,7 @@ from mne import Epochs, create_info, io, pick_types, read_events from mne._fiff.pick import _picks_to_idx -from mne.decoding import CSP +from mne.decoding import CSP, read_ssd from mne.decoding._mod_ged import _get_spectral_ratio from mne.decoding.ssd import SSD from mne.filter import filter_data @@ -361,6 +361,47 @@ def test_ssd_pipeline(): assert pipe.get_params()["SSD__n_components"] == 5 +def test_ssd_save_load(tmp_path): + """Test saving and loading of SSD.""" + X, _, _ = simulate_data() + sf = 250 + n_channels = X.shape[0] + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=5, + sort_by_spectral_ratio=True, + ) + ssd.fit(X) + + fname = tmp_path / "ssd.h5" + ssd.save(fname) + + ssd_loaded = read_ssd(fname) + + # Check numerical equivalence + X_orig = ssd.transform(X) + X_loaded = ssd_loaded.transform(X) + + assert_array_almost_equal(X_orig, X_loaded) + + def test_sorting(): """Test sorting learning during training.""" X, _, _ = simulate_data(n_trials=100, n_channels=20, n_samples=500)