-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
FIX: Add on_few_samples parameter to core rank estimation #13350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a8084cb
2014510
38184bc
b0ae9a2
3855a24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -553,16 +553,17 @@ def make_ad_hoc_cov(info, std=None, *, verbose=None): | |||||||||
return Covariance(data, ch_names, info["bads"], info["projs"], nfree=0) | ||||||||||
|
||||||||||
|
||||||||||
def _check_n_samples(n_samples, n_chan): | ||||||||||
def _check_n_samples(n_samples, n_chan, on_few_samples="warn"): | ||||||||||
"""Check to see if there are enough samples for reliable cov calc.""" | ||||||||||
n_samples_min = 10 * (n_chan + 1) // 2 | ||||||||||
if n_samples <= 0: | ||||||||||
raise ValueError("No samples found to compute the covariance matrix") | ||||||||||
if n_samples < n_samples_min: | ||||||||||
warn( | ||||||||||
msg = ( | ||||||||||
f"Too few samples (required : {n_samples_min} got : {n_samples}), " | ||||||||||
"covariance estimate may be unreliable" | ||||||||||
) | ||||||||||
_on_missing(on_few_samples, msg, "on_few_samples") | ||||||||||
|
||||||||||
|
||||||||||
@verbose | ||||||||||
|
@@ -582,6 +583,7 @@ def compute_raw_covariance( | |||||||||
return_estimators=False, | ||||||||||
reject_by_annotation=True, | ||||||||||
rank=None, | ||||||||||
on_few_samples="warn", | ||||||||||
verbose=None, | ||||||||||
): | ||||||||||
"""Estimate noise covariance matrix from a continuous segment of raw data. | ||||||||||
|
@@ -662,6 +664,10 @@ def compute_raw_covariance( | |||||||||
|
||||||||||
.. versionadded:: 0.18 | ||||||||||
Support for 'info' mode. | ||||||||||
on_few_samples : str | ||||||||||
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when | ||||||||||
there are fewer samples than channels, which can lead to inaccurate | ||||||||||
covariance or rank estimates. | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
%(verbose)s | ||||||||||
|
||||||||||
Returns | ||||||||||
|
@@ -736,7 +742,7 @@ def compute_raw_covariance( | |||||||||
mu += raw_segment.sum(axis=1) | ||||||||||
data += np.dot(raw_segment, raw_segment.T) | ||||||||||
n_samples += raw_segment.shape[1] | ||||||||||
_check_n_samples(n_samples, len(picks)) | ||||||||||
_check_n_samples(n_samples, len(picks), on_few_samples) | ||||||||||
data -= mu[:, None] * (mu[None, :] / n_samples) | ||||||||||
data /= n_samples - 1.0 | ||||||||||
logger.info("Number of samples used : %d", n_samples) | ||||||||||
|
@@ -872,6 +878,7 @@ def compute_covariance( | |||||||||
return_estimators=False, | ||||||||||
on_mismatch="raise", | ||||||||||
rank=None, | ||||||||||
on_few_samples="warn", | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Here maybe after |
||||||||||
verbose=None, | ||||||||||
): | ||||||||||
"""Estimate noise covariance matrix from epochs. | ||||||||||
|
@@ -966,6 +973,10 @@ def compute_covariance( | |||||||||
|
||||||||||
.. versionadded:: 0.18 | ||||||||||
Support for 'info' mode. | ||||||||||
on_few_samples : str | ||||||||||
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when | ||||||||||
there are fewer samples than channels, which can lead to inaccurate | ||||||||||
covariance or rank estimates. | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
%(verbose)s | ||||||||||
|
||||||||||
Returns | ||||||||||
|
@@ -1144,7 +1155,7 @@ def _unpack_epochs(epochs): | |||||||||
|
||||||||||
epochs = np.hstack(epochs) | ||||||||||
n_samples_tot = epochs.shape[-1] | ||||||||||
_check_n_samples(n_samples_tot, len(picks_meeg)) | ||||||||||
_check_n_samples(n_samples_tot, len(picks_meeg), on_few_samples) | ||||||||||
|
||||||||||
epochs = epochs.T # sklearn | C-order | ||||||||||
cov_data = _compute_covariance_auto( | ||||||||||
|
@@ -1158,6 +1169,7 @@ def _unpack_epochs(epochs): | |||||||||
picks_list=picks_list, | ||||||||||
scalings=scalings, | ||||||||||
rank=rank, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
) | ||||||||||
|
||||||||||
if keep_sample_mean is False: | ||||||||||
|
@@ -1221,7 +1233,7 @@ def _eigvec_subspace(eig, eigvec, mask): | |||||||||
|
||||||||||
@verbose | ||||||||||
def _compute_rank_raw_array( | ||||||||||
data, info, rank, scalings, *, log_ch_type=None, verbose=None | ||||||||||
data, info, rank, scalings, *, log_ch_type=None, on_few_samples="warn", verbose=None | ||||||||||
): | ||||||||||
from .io import RawArray | ||||||||||
|
||||||||||
|
@@ -1231,6 +1243,7 @@ def _compute_rank_raw_array( | |||||||||
scalings, | ||||||||||
info, | ||||||||||
log_ch_type=log_ch_type, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -1249,6 +1262,7 @@ def _compute_covariance_auto( | |||||||||
cov_kind="", | ||||||||||
log_ch_type=None, | ||||||||||
log_rank=True, | ||||||||||
on_few_samples="warn", | ||||||||||
): | ||||||||||
"""Compute covariance auto mode.""" | ||||||||||
# rescale to improve numerical stability | ||||||||||
|
@@ -1258,6 +1272,7 @@ def _compute_covariance_auto( | |||||||||
info, | ||||||||||
rank=rank, | ||||||||||
scalings=scalings, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
verbose=_verbose_safe_false(), | ||||||||||
) | ||||||||||
with _scaled_array(data.T, picks_list, scalings): | ||||||||||
|
@@ -1268,6 +1283,7 @@ def _compute_covariance_auto( | |||||||||
rank, | ||||||||||
proj_subspace=True, | ||||||||||
do_compute_rank=False, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
log_ch_type=log_ch_type, | ||||||||||
verbose=None if log_rank else _verbose_safe_false(), | ||||||||||
) | ||||||||||
|
@@ -1729,6 +1745,7 @@ def prepare_noise_cov( | |||||||||
rank=None, | ||||||||||
scalings=None, | ||||||||||
on_rank_mismatch="ignore", | ||||||||||
on_few_samples="warn", | ||||||||||
verbose=None, | ||||||||||
): | ||||||||||
"""Prepare noise covariance matrix. | ||||||||||
|
@@ -1751,6 +1768,10 @@ def prepare_noise_cov( | |||||||||
|
||||||||||
dict(mag=1e12, grad=1e11, eeg=1e5) | ||||||||||
%(on_rank_mismatch)s | ||||||||||
on_few_samples : str | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Only functions that compute a covariance should end up warning about the number of samples. Things like |
||||||||||
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when | ||||||||||
there are fewer samples than channels, which can lead to inaccurate | ||||||||||
covariance or rank estimates. | ||||||||||
%(verbose)s | ||||||||||
|
||||||||||
Returns | ||||||||||
|
@@ -1792,6 +1813,7 @@ def prepare_noise_cov( | |||||||||
projs, | ||||||||||
ch_names, | ||||||||||
on_rank_mismatch=on_rank_mismatch, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
) | ||||||||||
noise_cov.update(eig=eig, eigvec=eigvec) | ||||||||||
return noise_cov | ||||||||||
|
@@ -1808,6 +1830,7 @@ def _smart_eigh( | |||||||||
proj_subspace=False, | ||||||||||
do_compute_rank=True, | ||||||||||
on_rank_mismatch="ignore", | ||||||||||
on_few_samples="warn", | ||||||||||
*, | ||||||||||
log_ch_type=None, | ||||||||||
verbose=None, | ||||||||||
|
@@ -1838,6 +1861,7 @@ def _smart_eigh( | |||||||||
scalings, | ||||||||||
info, | ||||||||||
on_rank_mismatch=on_rank_mismatch, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
log_ch_type=log_ch_type, | ||||||||||
) | ||||||||||
assert C.ndim == 2 and C.shape[0] == C.shape[1] | ||||||||||
|
@@ -1916,6 +1940,7 @@ def regularize( | |||||||||
dbs=0.1, | ||||||||||
rank=None, | ||||||||||
scalings=None, | ||||||||||
on_few_samples="warn", | ||||||||||
verbose=None, | ||||||||||
): | ||||||||||
"""Regularize noise covariance matrix. | ||||||||||
|
@@ -1978,6 +2003,10 @@ def regularize( | |||||||||
See :func:`mne.compute_covariance`. | ||||||||||
|
||||||||||
.. versionadded:: 0.17 | ||||||||||
on_few_samples : str | ||||||||||
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when | ||||||||||
there are fewer samples than channels, which can lead to inaccurate | ||||||||||
covariance or rank estimates. | ||||||||||
%(verbose)s | ||||||||||
|
||||||||||
Returns | ||||||||||
|
@@ -2032,7 +2061,7 @@ def regularize( | |||||||||
else: | ||||||||||
regs.update(mag=mag, grad=grad) | ||||||||||
if rank != "full": | ||||||||||
rank = _compute_rank(cov, rank, scalings, info) | ||||||||||
rank = _compute_rank(cov, rank, scalings, info, on_few_samples=on_few_samples) | ||||||||||
|
||||||||||
info_ch_names = info["ch_names"] | ||||||||||
ch_names_by_type = dict() | ||||||||||
|
@@ -2092,7 +2121,9 @@ def regularize( | |||||||||
this_info = pick_info(info, this_picks) | ||||||||||
# Here we could use proj_subspace=True, but this should not matter | ||||||||||
# since this is already in a loop over channel types | ||||||||||
_, eigvec, mask = _smart_eigh(this_C, this_info, rank) | ||||||||||
_, eigvec, mask = _smart_eigh( | ||||||||||
this_C, this_info, rank, on_few_samples=on_few_samples | ||||||||||
) | ||||||||||
U = eigvec[mask].T | ||||||||||
this_C = np.dot(U.T, np.dot(this_C, U)) | ||||||||||
|
||||||||||
|
@@ -2119,6 +2150,7 @@ def _regularized_covariance( | |||||||||
log_ch_type=None, | ||||||||||
log_rank=None, | ||||||||||
cov_kind="", | ||||||||||
on_few_samples="warn", | ||||||||||
verbose=None, | ||||||||||
): | ||||||||||
"""Compute a regularized covariance from data using sklearn. | ||||||||||
|
@@ -2166,6 +2198,7 @@ def _regularized_covariance( | |||||||||
cov_kind=cov_kind, | ||||||||||
log_ch_type=log_ch_type, | ||||||||||
log_rank=log_rank, | ||||||||||
on_few_samples=on_few_samples, | ||||||||||
)[reg]["data"] | ||||||||||
return cov | ||||||||||
|
||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nowadays we try to add `*` to functions (and we've been adding them to old functions as we come across them), like [code example not captured in this extract]. But really we could probably go with something much farther up... maybe after `picks` and before `method`?