Skip to content
25 changes: 25 additions & 0 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,31 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._is_base_ged = False

def __getstate__(self):
"""Get state for serialization.

This explicitly drops callables and other runtime-only attributes.
Subclasses can extend this to add estimator-specific state.
"""
state = self.__dict__.copy()

# Callables are not serializable and must be reconstructed
state.pop("cov_callable", None)
state.pop("mod_ged_callable", None)

return state

def __setstate__(self, state):
"""Restore state from serialization.

Subclasses are responsible for reconstructing dropped callables.
"""
self.__dict__.update(state)

# Ensure attributes exist even before reconstruction
self.cov_callable = None
self.mod_ged_callable = None

def fit(self, X, y=None):
"""..."""
# Let the inheriting transformers check data by themselves
Expand Down
112 changes: 112 additions & 0 deletions mne/decoding/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import numpy as np

from mne import __version__ as mne_version
from mne.utils import check_version

from .._fiff.meas_info import Info, create_info
from .._fiff.pick import _picks_to_idx
from ..filter import filter_data
Expand Down Expand Up @@ -146,6 +149,63 @@ def __init__(
restr_type=restr_type,
)

def __getstate__(self):
"""Get state for serialization."""
state = super().__getstate__()

# init parameters
state.update(
info=self.info,
filt_params_signal=self.filt_params_signal,
filt_params_noise=self.filt_params_noise,
reg=self.reg,
n_components=self.n_components,
picks=self.picks,
sort_by_spectral_ratio=self.sort_by_spectral_ratio,
return_filtered=self.return_filtered,
n_fft=self.n_fft,
cov_method_params=self.cov_method_params,
restr_type=self.restr_type,
rank=self.rank,
)

# fitted attributes (only if present)
for attr in (
"filters_",
"patterns_",
"evals_",
"picks_",
"freqs_signal_",
"freqs_noise_",
"n_fft_",
"sfreq_",
):
if hasattr(self, attr):
state[attr] = getattr(self, attr)

return state

def __setstate__(self, state):
"""Restore state from serialization."""
super().__setstate__(state)

# Restore attributes
self.__dict__.update(state)

# Rebuild covariance callable exactly as in __init__
self.cov_callable = partial(
_ssd_estimate,
reg=self.reg,
cov_method_params=self.cov_method_params,
info=self.info,
picks=self.picks,
n_fft=self.n_fft,
filt_params_signal=self.filt_params_signal,
filt_params_noise=self.filt_params_noise,
rank=self.rank,
sort_by_spectral_ratio=self.sort_by_spectral_ratio,
)

def _validate_params(self, X):
if isinstance(self.info, float): # special case, mostly for testing
self.sfreq_ = self.info
Expand Down Expand Up @@ -237,6 +297,13 @@ def fit(self, X, y=None):
logger.info("Done.")
return self

def save(self, fname, overwrite=False):
state = self.__getstate__()
state.update(
class_name="SSD",
mne_version=mne_version,
)

def transform(self, X):
"""Estimate epochs sources given the SSD filters.

Expand Down Expand Up @@ -350,3 +417,48 @@ def apply(self, X):
pick_patterns = self.patterns_[: self.n_components].T
X = pick_patterns @ X_ssd
return X


def read_ssd(fname):
"""Read an SSD object from disk.

Parameters
----------
fname : path-like
Path to ``.h5`` file.

Returns
-------
ssd : SSD
The loaded SSD object.
"""
from ..utils.check import _import_h5io_funcs

_validate_type(fname, "path-like", "fname")
check_version("h5py")

read_hdf5, _ = _import_h5io_funcs()
state = read_hdf5(fname, title="mne-python SSD")

if state.get("class_name") != "SSD":
raise RuntimeError("The file does not contain a valid SSD object.")

ssd = SSD(
info=state["info"],
filt_params_signal=state["filt_params_signal"],
filt_params_noise=state["filt_params_noise"],
reg=state["reg"],
n_components=state["n_components"],
picks=state["picks"],
sort_by_spectral_ratio=state["sort_by_spectral_ratio"],
return_filtered=state["return_filtered"],
n_fft=state["n_fft"],
cov_method_params=state["cov_method_params"],
restr_type=state["restr_type"],
rank=state["rank"],
)

# restore full state (fitted attributes + callables)
ssd.__setstate__(state)

return ssd
43 changes: 42 additions & 1 deletion mne/decoding/tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from mne import Epochs, create_info, io, pick_types, read_events
from mne._fiff.pick import _picks_to_idx
from mne.decoding import CSP
from mne.decoding import CSP, read_ssd
from mne.decoding._mod_ged import _get_spectral_ratio
from mne.decoding.ssd import SSD
from mne.filter import filter_data
Expand Down Expand Up @@ -361,6 +361,47 @@ def test_ssd_pipeline():
assert pipe.get_params()["SSD__n_components"] == 5


def test_ssd_save_load(tmp_path):
"""Test saving and loading of SSD."""
X, _, _ = simulate_data()
sf = 250
n_channels = X.shape[0]
info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")

filt_params_signal = dict(
l_freq=freqs_sig[0],
h_freq=freqs_sig[1],
l_trans_bandwidth=1,
h_trans_bandwidth=1,
)
filt_params_noise = dict(
l_freq=freqs_noise[0],
h_freq=freqs_noise[1],
l_trans_bandwidth=1,
h_trans_bandwidth=1,
)

ssd = SSD(
info,
filt_params_signal,
filt_params_noise,
n_components=5,
sort_by_spectral_ratio=True,
)
ssd.fit(X)

fname = tmp_path / "ssd.h5"
ssd.save(fname)

ssd_loaded = read_ssd(fname)

# Check numerical equivalence
X_orig = ssd.transform(X)
X_loaded = ssd_loaded.transform(X)

assert_array_almost_equal(X_orig, X_loaded)


def test_sorting():
"""Test sorting learning during training."""
X, _, _ = simulate_data(n_trials=100, n_channels=20, n_samples=500)
Expand Down
Loading