Skip to content
Open
3 changes: 2 additions & 1 deletion mne/decoding/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ __all__ = [
"cross_val_multiscore",
"get_coef",
"get_spatial_filter_from_estimator",
"read_ssd",
]
from .base import (
BaseEstimator,
Expand All @@ -36,7 +37,7 @@ from .ems import EMS, compute_ems
from .receptive_field import ReceptiveField
from .search_light import GeneralizingEstimator, SlidingEstimator
from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator
from .ssd import SSD
from .ssd import SSD, read_ssd
from .time_delaying_ridge import TimeDelayingRidge
from .time_frequency import TimeFrequency
from .transformer import (
Expand Down
25 changes: 25 additions & 0 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,31 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._is_base_ged = False

def __getstate__(self):
"""Get state for serialization.
This explicitly drops callables and other runtime-only attributes.
Subclasses can extend this to add estimator-specific state.
"""
state = self.__dict__.copy()

# Callables are not serializable and must be reconstructed
state.pop("cov_callable", None)
state.pop("mod_ged_callable", None)

return state

def __setstate__(self, state):
"""Restore state from serialization.
Subclasses are responsible for reconstructing dropped callables.
"""
self.__dict__.update(state)

# Ensure attributes exist even before reconstruction
self.cov_callable = None
self.mod_ged_callable = None

def fit(self, X, y=None):
"""..."""
# Let the inheriting transformers check data by themselves
Expand Down
61 changes: 61 additions & 0 deletions mne/decoding/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np

from mne import __version__ as mne_version

from .._fiff.meas_info import Info, create_info
from .._fiff.pick import _picks_to_idx
from ..filter import filter_data
Expand All @@ -15,6 +17,7 @@
fill_doc,
logger,
)
from ..utils.check import _check_fname, _import_h5io_funcs, check_fname
from ._covs_ged import _ssd_estimate
from ._mod_ged import _get_spectral_ratio, _ssd_mod
from .base import _GEDTransformer
Expand Down Expand Up @@ -146,6 +149,14 @@ def __init__(
restr_type=restr_type,
)

def __setstate__(self, state):
"""Restore state from serialization."""
# Since read_ssd creates a new instance via __init__ first,
# callables are already set correctly. We just restore fitted attributes.
# Don't call super().__setstate__() as it would set callables to None.
self.__dict__.update(state)
return self

def _validate_params(self, X):
if isinstance(self.info, float): # special case, mostly for testing
self.sfreq_ = self.info
Expand Down Expand Up @@ -237,6 +248,33 @@ def fit(self, X, y=None):
logger.info("Done.")
return self

@fill_doc
def save(self, fname, *, overwrite=False, verbose=None):
"""Save the SSD decomposition to disk (in HDF5 format).

Parameters
----------
fname : path-like
Path of file to save to. Should end with ``'.h5'`` or ``'.hdf5'``.
%(overwrite)s
%(verbose)s

See Also
--------
mne.decoding.read_ssd
"""
_, write_hdf5 = _import_h5io_funcs()
check_fname(fname, "SSD", (".h5", ".hdf5"))
fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
state = self.__getstate__()
state.update(
class_name="SSD",
mne_version=mne_version,
)
write_hdf5(
fname, state, overwrite=overwrite, title="mnepython", slash="replace"
)

def transform(self, X):
"""Estimate epochs sources given the SSD filters.

Expand Down Expand Up @@ -350,3 +388,26 @@ def apply(self, X):
pick_patterns = self.patterns_[: self.n_components].T
X = pick_patterns @ X_ssd
return X


def read_ssd(fname):
"""Read an SSD object from disk.

Parameters
----------
fname : path-like
Path to an SSD file in HDF5 format, which should end with ``.h5`` or
``.hdf5``.

Returns
-------
ssd : SSD
The loaded SSD object.
"""
read_hdf5, _ = _import_h5io_funcs()
_validate_type(fname, "path-like", "fname")
fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
state = read_hdf5(fname, title="mnepython", slash="replace")
return SSD(info=None, filt_params_signal=None, filt_params_noise=None).__setstate__(
state
)
43 changes: 42 additions & 1 deletion mne/decoding/tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from mne import Epochs, create_info, io, pick_types, read_events
from mne._fiff.pick import _picks_to_idx
from mne.decoding import CSP
from mne.decoding import CSP, read_ssd
from mne.decoding._mod_ged import _get_spectral_ratio
from mne.decoding.ssd import SSD
from mne.filter import filter_data
Expand Down Expand Up @@ -361,6 +361,47 @@ def test_ssd_pipeline():
assert pipe.get_params()["SSD__n_components"] == 5


def test_ssd_save_load(tmp_path):
"""Test saving and loading of SSD."""
X, _, _ = simulate_data()
sf = 250
n_channels = X.shape[0]
info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")

filt_params_signal = dict(
l_freq=freqs_sig[0],
h_freq=freqs_sig[1],
l_trans_bandwidth=1,
h_trans_bandwidth=1,
)
filt_params_noise = dict(
l_freq=freqs_noise[0],
h_freq=freqs_noise[1],
l_trans_bandwidth=1,
h_trans_bandwidth=1,
)

ssd = SSD(
info,
filt_params_signal,
filt_params_noise,
n_components=5,
sort_by_spectral_ratio=True,
)
ssd.fit(X)

fname = tmp_path / "ssd.h5"
ssd.save(fname)

ssd_loaded = read_ssd(fname)

# Check numerical equivalence
X_orig = ssd.transform(X)
X_loaded = ssd_loaded.transform(X)

assert_array_almost_equal(X_orig, X_loaded)


def test_sorting():
"""Test sorting learning during training."""
X, _, _ = simulate_data(n_trials=100, n_channels=20, n_samples=500)
Expand Down