-
-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix: added save and read functionality in SSD issue #13328 #13589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
|
@larsoner @mscheltienne Kindly review this PR and let me know if any changes are required |
…/mne-python into fix-ssd-save-load merge branches .
|
I'm happy to look over this @larsoner |
|
@Anushreebasics Thanks for opening the PR. It would be good to set some explicit Additionally, we should really also add this save/load behaviour to Please have a think about adding |
for more information, see https://pre-commit.ci
|
@tsbinns please review the changes |
tsbinns
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes @Anushreebasics. A couple comments and suggestions here.
Also, make sure you add any new functions to the __init__.pyi file, otherwise they can't be imported. This also means that the test which has been added can't run at the moment. Please have a go at running any tests you are adding/changing locally first, that way you'll have a much clearer idea of how the changes are (mis)behaving.
You can find info about running the tests in the contribution guide (here and here). Feel free to ask if you have questions about getting this running.
|
@tsbinns please review the changes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| filt_params_noise=self.filt_params_noise, | ||
| rank=self.rank, | ||
| sort_by_spectral_ratio=self.sort_by_spectral_ratio, | ||
| ) |
Copilot
AI
Jan 18, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The setstate method is not reconstructing the mod_ged_callable attribute, which is required by the base class _GEDTransformer. According to the init method at line 148, this should be set to _ssd_mod. Add 'self.mod_ged_callable = _ssd_mod' after reconstructing the cov_callable.
| ) | |
| ) | |
| # Rebuild modulation GED callable exactly as in __init__ | |
| self.mod_ged_callable = _ssd_mod |
| "freqs_signal_", | ||
| "freqs_noise_", | ||
| "n_fft_", | ||
| "sfreq_", |
Copilot
AI
Jan 18, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fitted attribute 'sorter_' is not being explicitly preserved in getstate. This attribute is set by the parent class GEDTransformer.fit() method and may be needed for proper deserialization. Although the parent's getstate copies dict, explicitly including 'sorter' in the fitted attributes list would make the serialization more robust and clear.
| "sfreq_", | |
| "sfreq_", | |
| "sorter_", |
| return X | ||
|
|
||
|
|
||
| def read_ssd(fname): |
Copilot
AI
Jan 18, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The read_ssd function needs to be added to the public API. It should be included in mne/decoding/init.pyi's all list and imported in the stub file so users can import it from mne.decoding. The test imports it directly from mne.decoding, so it must be publicly exported.
| def save(self, fname, overwrite=False): | ||
| state = self.__getstate__() | ||
| state.update( | ||
| class_name="SSD", | ||
| mne_version=mne_version, | ||
| ) |
Copilot
AI
Jan 18, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The save method is missing its docstring. It should document the fname and overwrite parameters, describe what the method does, and specify the expected file format (.h5). Follow the docstring pattern used in other MNE save methods like the one in Beamformer.save().
| def save(self, fname, overwrite=False): | ||
| state = self.__getstate__() | ||
| state.update( | ||
| class_name="SSD", | ||
| mne_version=mne_version, | ||
| ) |
Copilot
AI
Jan 18, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The save method is incomplete - it creates the state dictionary but never actually writes it to disk. You need to call write_hdf5 to persist the data. The method should import and use write_hdf5 similar to how read_ssd uses read_hdf5. Add the missing write call after line 305.
| ssd = SSD( | ||
| info=state["info"], | ||
| filt_params_signal=state["filt_params_signal"], | ||
| filt_params_noise=state["filt_params_noise"], | ||
| reg=state["reg"], | ||
| n_components=state["n_components"], | ||
| picks=state["picks"], | ||
| sort_by_spectral_ratio=state["sort_by_spectral_ratio"], | ||
| return_filtered=state["return_filtered"], | ||
| n_fft=state["n_fft"], | ||
| cov_method_params=state["cov_method_params"], | ||
| restr_type=state["restr_type"], | ||
| rank=state["rank"], | ||
| ) |
Copilot
AI
Jan 18, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect indentation: the SSD instantiation (lines 446-459) is indented under the if statement on line 443, making it unreachable when the condition is True. This code should be unindented to execute after the validation check. The SSD instantiation should occur regardless of whether the check passes (as long as no exception is raised), so it needs to be at the same indentation level as the if statement.
| ssd = SSD( | |
| info=state["info"], | |
| filt_params_signal=state["filt_params_signal"], | |
| filt_params_noise=state["filt_params_noise"], | |
| reg=state["reg"], | |
| n_components=state["n_components"], | |
| picks=state["picks"], | |
| sort_by_spectral_ratio=state["sort_by_spectral_ratio"], | |
| return_filtered=state["return_filtered"], | |
| n_fft=state["n_fft"], | |
| cov_method_params=state["cov_method_params"], | |
| restr_type=state["restr_type"], | |
| rank=state["rank"], | |
| ) | |
| ssd = SSD( | |
| info=state["info"], | |
| filt_params_signal=state["filt_params_signal"], | |
| filt_params_noise=state["filt_params_noise"], | |
| reg=state["reg"], | |
| n_components=state["n_components"], | |
| picks=state["picks"], | |
| sort_by_spectral_ratio=state["sort_by_spectral_ratio"], | |
| return_filtered=state["return_filtered"], | |
| n_fft=state["n_fft"], | |
| cov_method_params=state["cov_method_params"], | |
| restr_type=state["restr_type"], | |
| rank=state["rank"], | |
| ) |
issue #13328
Summary
This PR fixes the SSD serialization logic so that objects saved with SSD.save() can be fully restored using SSD.read() without losing internal state.
Previously, some fitted attributes required to reconstruct a trained SSD instance were not persisted, which caused test_ssd_save_load to fail during deserialization.
What was changed?
Refactored the SSD.save() method to persist all fitted attributes required to restore a trained instance.
Updated the corresponding SSD.read() logic to correctly re-initialize these attributes when loading from disk.
Ensured that save/load performs a true round-trip and yields an SSD instance equivalent to the original fitted object.