Skip to content

Commit 3fbf986

Browse files
emmanuel-ferdmanpre-commit-ci[bot]larsoner
authored
FIX: Add on_few_samples parameter to core rank estimation (#13350)
Signed-off-by: Emmanuel Ferdman <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson <[email protected]>
1 parent 08ef1c7 commit 3fbf986

File tree

8 files changed

+105
-17
lines changed

8 files changed

+105
-17
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add ``on_few_samples`` parameter to :func:`mne.compute_covariance` and :func:`mne.compute_raw_covariance` for controlling behavior when there are fewer samples than channels, which can lead to inaccurate covariance estimates, by :newcontrib:`Emmanuel Ferdman`.

doc/changes/names.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
.. _Eduard Ort: https://github.com/eort
7777
.. _Emily Stephen: https://github.com/emilyps14
7878
.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey
79+
.. _Emmanuel Ferdman: https://github.com/emmanuel-ferdman
7980
.. _Emrecan Çelik: https://github.com/emrecncelik
8081
.. _Enrico Varano: https://github.com/enricovara/
8182
.. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt

mne/cov.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -553,16 +553,17 @@ def make_ad_hoc_cov(info, std=None, *, verbose=None):
553553
return Covariance(data, ch_names, info["bads"], info["projs"], nfree=0)
554554

555555

556-
def _check_n_samples(n_samples, n_chan):
556+
def _check_n_samples(n_samples, n_chan, on_few_samples="warn"):
557557
"""Check to see if there are enough samples for reliable cov calc."""
558558
n_samples_min = 10 * (n_chan + 1) // 2
559559
if n_samples <= 0:
560560
raise ValueError("No samples found to compute the covariance matrix")
561561
if n_samples < n_samples_min:
562-
warn(
563-
f"Too few samples (required : {n_samples_min} got : {n_samples}), "
564-
"covariance estimate may be unreliable"
562+
msg = (
563+
f"Too few samples (required {n_samples_min} but got {n_samples} for "
564+
f"{n_chan} channels), covariance estimate may be unreliable"
565565
)
566+
_on_missing(on_few_samples, msg, "on_few_samples")
566567

567568

568569
@verbose
@@ -574,6 +575,8 @@ def compute_raw_covariance(
574575
reject=None,
575576
flat=None,
576577
picks=None,
578+
*,
579+
on_few_samples="warn",
577580
method="empirical",
578581
method_params=None,
579582
cv=3,
@@ -623,6 +626,12 @@ def compute_raw_covariance(
623626
are floats that set the minimum acceptable peak-to-peak amplitude.
624627
If flat is None then no rejection is done.
625628
%(picks_good_data_noref)s
629+
on_few_samples : str
630+
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
631+
there are fewer samples than channels, which can lead to inaccurate
632+
covariance or rank estimates.
633+
634+
.. versionadded:: 1.11
626635
method : str | list | None (default 'empirical')
627636
The method used for covariance estimation.
628637
See :func:`mne.compute_covariance`.
@@ -736,7 +745,7 @@ def compute_raw_covariance(
736745
mu += raw_segment.sum(axis=1)
737746
data += np.dot(raw_segment, raw_segment.T)
738747
n_samples += raw_segment.shape[1]
739-
_check_n_samples(n_samples, len(picks))
748+
_check_n_samples(n_samples, len(picks), on_few_samples)
740749
data -= mu[:, None] * (mu[None, :] / n_samples)
741750
data /= n_samples - 1.0
742751
logger.info("Number of samples used : %d", n_samples)
@@ -864,6 +873,8 @@ def compute_covariance(
864873
tmin=None,
865874
tmax=None,
866875
projs=None,
876+
*,
877+
on_few_samples="warn",
867878
method="empirical",
868879
method_params=None,
869880
cv=3,
@@ -909,6 +920,12 @@ def compute_covariance(
909920
List of projectors to use in covariance calculation, or None
910921
to indicate that the projectors from the epochs should be
911922
inherited. If None, then projectors from all epochs must match.
923+
on_few_samples : str
924+
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
925+
there are fewer samples than channels, which can lead to inaccurate
926+
covariance or rank estimates.
927+
928+
.. versionadded:: 1.11
912929
method : str | list | None (default 'empirical')
913930
The method used for covariance estimation. If 'empirical' (default),
914931
the sample covariance will be computed. A list can be passed to
@@ -1144,7 +1161,7 @@ def _unpack_epochs(epochs):
11441161

11451162
epochs = np.hstack(epochs)
11461163
n_samples_tot = epochs.shape[-1]
1147-
_check_n_samples(n_samples_tot, len(picks_meeg))
1164+
_check_n_samples(n_samples_tot, len(picks_meeg), on_few_samples)
11481165

11491166
epochs = epochs.T # sklearn | C-order
11501167
cov_data = _compute_covariance_auto(
@@ -1158,6 +1175,7 @@ def _unpack_epochs(epochs):
11581175
picks_list=picks_list,
11591176
scalings=scalings,
11601177
rank=rank,
1178+
on_few_samples=on_few_samples,
11611179
)
11621180

11631181
if keep_sample_mean is False:
@@ -1221,7 +1239,7 @@ def _eigvec_subspace(eig, eigvec, mask):
12211239

12221240
@verbose
12231241
def _compute_rank_raw_array(
1224-
data, info, rank, scalings, *, log_ch_type=None, verbose=None
1242+
data, info, rank, scalings, *, log_ch_type=None, on_few_samples="warn", verbose=None
12251243
):
12261244
from .io import RawArray
12271245

@@ -1231,6 +1249,7 @@ def _compute_rank_raw_array(
12311249
scalings,
12321250
info,
12331251
log_ch_type=log_ch_type,
1252+
on_few_samples=on_few_samples,
12341253
)
12351254

12361255

@@ -1249,6 +1268,7 @@ def _compute_covariance_auto(
12491268
cov_kind="",
12501269
log_ch_type=None,
12511270
log_rank=True,
1271+
on_few_samples="warn",
12521272
):
12531273
"""Compute covariance auto mode."""
12541274
# rescale to improve numerical stability
@@ -1258,6 +1278,7 @@ def _compute_covariance_auto(
12581278
info,
12591279
rank=rank,
12601280
scalings=scalings,
1281+
on_few_samples=on_few_samples,
12611282
verbose=_verbose_safe_false(),
12621283
)
12631284
with _scaled_array(data.T, picks_list, scalings):
@@ -1729,6 +1750,7 @@ def prepare_noise_cov(
17291750
rank=None,
17301751
scalings=None,
17311752
on_rank_mismatch="ignore",
1753+
*,
17321754
verbose=None,
17331755
):
17341756
"""Prepare noise covariance matrix.
@@ -2119,6 +2141,9 @@ def _regularized_covariance(
21192141
log_ch_type=None,
21202142
log_rank=None,
21212143
cov_kind="",
2144+
# backward-compat default for decoding (maybe someday we want to expose this but
2145+
# it's likely too invasive and since it's usually regularized, unnecessary):
2146+
on_few_samples="ignore",
21222147
verbose=None,
21232148
):
21242149
"""Compute a regularized covariance from data using sklearn.
@@ -2166,6 +2191,7 @@ def _regularized_covariance(
21662191
cov_kind=cov_kind,
21672192
log_ch_type=log_ch_type,
21682193
log_rank=log_rank,
2194+
on_few_samples=on_few_samples,
21692195
)[reg]["data"]
21702196
return cov
21712197

mne/decoding/_covs_ged.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, info, rank, norm_trace)
8888
rank=rank,
8989
scalings=None,
9090
log_ch_type="data",
91+
on_few_samples="ignore",
9192
)
9293

9394
covs = []
@@ -158,6 +159,7 @@ def _xdawn_estimate(
158159
rank=rank,
159160
scalings=None,
160161
log_ch_type="data",
162+
on_few_samples="ignore",
161163
)
162164
return covs, C_ref, info, rank, dict()
163165

@@ -280,5 +282,6 @@ def _spoc_estimate(X, y, reg, cov_method_params, info, rank):
280282
rank=rank,
281283
scalings=None,
282284
log_ch_type="data",
285+
on_few_samples="ignore",
283286
)
284287
return covs, C_ref, info, rank, dict()

mne/decoding/tests/test_csp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def test_spoc():
431431
# check y
432432
pytest.raises(ValueError, spoc.fit, X, y * 0)
433433

434-
# Check that doesn't take CSP-spcific input
434+
# Check that doesn't take CSP-specific input
435435
pytest.raises(TypeError, SPoC, cov_est="epoch")
436436

437437
# Check mixing matrix on simulated data

mne/rank.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,27 @@ def _estimate_rank_from_s(s, tol="auto", tol_kind="absolute"):
130130

131131

132132
def _estimate_rank_raw(
133-
raw, picks=None, tol=1e-4, scalings="norm", with_ref_meg=False, tol_kind="absolute"
133+
raw,
134+
picks=None,
135+
tol=1e-4,
136+
scalings="norm",
137+
with_ref_meg=False,
138+
tol_kind="absolute",
139+
on_few_samples="warn",
134140
):
135141
"""Aid the transition away from raw.estimate_rank."""
136142
if picks is None:
137143
picks = _picks_to_idx(raw.info, picks, with_ref_meg=with_ref_meg)
138144
# conveniency wrapper to expose the expert "tol" option + scalings options
139145
return _estimate_rank_meeg_signals(
140-
raw[picks][0], pick_info(raw.info, picks), scalings, tol, False, tol_kind
146+
raw[picks][0],
147+
pick_info(raw.info, picks),
148+
scalings,
149+
tol,
150+
False,
151+
tol_kind,
152+
log_ch_type=None,
153+
on_few_samples=on_few_samples,
141154
)
142155

143156

@@ -150,6 +163,7 @@ def _estimate_rank_meeg_signals(
150163
return_singular=False,
151164
tol_kind="absolute",
152165
log_ch_type=None,
166+
on_few_samples="warn",
153167
):
154168
"""Estimate rank for M/EEG data.
155169
@@ -173,6 +187,10 @@ def _estimate_rank_meeg_signals(
173187
to determine the rank.
174188
tol_kind : str
175189
Tolerance kind. See ``estimate_rank``.
190+
on_few_samples : str
191+
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
192+
there are fewer samples than channels, which can lead to inaccurate rank
193+
estimates.
176194
177195
Returns
178196
-------
@@ -183,11 +201,14 @@ def _estimate_rank_meeg_signals(
183201
thresholded to determine the rank are also returned.
184202
"""
185203
picks_list = _picks_by_type(info)
186-
if data.shape[1] < data.shape[0]:
187-
ValueError(
188-
"You've got fewer samples than channels, your "
189-
"rank estimate might be inaccurate."
204+
assert data.ndim == 2, data.shape
205+
n_channels, n_samples = data.shape
206+
if n_samples < n_channels:
207+
msg = (
208+
f"Too few samples ({n_samples=} is less than {n_channels=}), "
209+
"rank estimate may be unreliable"
190210
)
211+
_on_missing(on_few_samples, msg, "on_few_samples")
191212
with _scaled_array(data, picks_list, scalings):
192213
out = estimate_rank(
193214
data,
@@ -214,6 +235,7 @@ def _estimate_rank_meeg_cov(
214235
return_singular=False,
215236
*,
216237
log_ch_type=None,
238+
on_few_samples="warn",
217239
verbose=None,
218240
):
219241
"""Estimate rank of M/EEG covariance data, given the covariance.
@@ -236,6 +258,10 @@ def _estimate_rank_meeg_cov(
236258
return_singular : bool
237259
If True, also return the singular values that were used
238260
to determine the rank.
261+
on_few_samples : str
262+
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
263+
there are fewer samples than channels, which can lead to inaccurate rank
264+
estimates.
239265
240266
Returns
241267
-------
@@ -249,10 +275,11 @@ def _estimate_rank_meeg_cov(
249275
scalings = _handle_default("scalings_cov_rank", scalings)
250276
_apply_scaling_cov(data, picks_list, scalings)
251277
if data.shape[1] < data.shape[0]:
252-
ValueError(
278+
msg = (
253279
"You've got fewer samples than channels, your "
254280
"rank estimate might be inaccurate."
255281
)
282+
_on_missing(on_few_samples, msg, "on_few_samples")
256283
out = estimate_rank(data, tol=tol, norm=False, return_singular=return_singular)
257284
rank = out[0] if isinstance(out, tuple) else out
258285
if log_ch_type is None:
@@ -325,7 +352,7 @@ def _compute_rank_int(inst, *args, **kwargs):
325352
# XXX eventually we should unify how channel types are handled
326353
# so that we don't need to do this, or we do it everywhere.
327354
# Using pca=True in compute_whitener might help.
328-
return sum(compute_rank(inst, *args, **kwargs).values())
355+
return sum(compute_rank(inst, *args, on_few_samples="ignore", **kwargs).values())
329356

330357

331358
@verbose
@@ -335,9 +362,11 @@ def compute_rank(
335362
scalings=None,
336363
info=None,
337364
tol="auto",
365+
*,
338366
proj=True,
339367
tol_kind="absolute",
340368
on_rank_mismatch="ignore",
369+
on_few_samples=None,
341370
verbose=None,
342371
):
343372
"""Compute the rank of data or noise covariance.
@@ -363,6 +392,13 @@ def compute_rank(
363392
considered when ``rank=None`` or ``rank='info'``.
364393
%(tol_kind_rank)s
365394
%(on_rank_mismatch)s
395+
on_few_samples : str | None
396+
Can be 'warn', 'ignore', or 'raise' to control behavior when
397+
there are fewer samples than channels, which can lead to inaccurate rank
398+
estimates. None (default) means "ignore" if ``inst`` is a
399+
:class:`mne.Covariance` or ``rank in ("info", "full")``, and "warn" otherwise.
400+
401+
.. versionadded:: 1.11
366402
%(verbose)s
367403
368404
Returns
@@ -384,6 +420,7 @@ def compute_rank(
384420
proj=proj,
385421
tol_kind=tol_kind,
386422
on_rank_mismatch=on_rank_mismatch,
423+
on_few_samples=on_few_samples,
387424
)
388425

389426

@@ -398,6 +435,7 @@ def _compute_rank(
398435
proj=True,
399436
tol_kind="absolute",
400437
on_rank_mismatch="ignore",
438+
on_few_samples=None,
401439
log_ch_type=None,
402440
verbose=None,
403441
):
@@ -441,6 +479,12 @@ def _compute_rank(
441479
if rank is None:
442480
rank = dict()
443481

482+
if on_few_samples is None:
483+
if inst_type != "covariance" and rank_type == "estimated":
484+
on_few_samples = "warn"
485+
else:
486+
on_few_samples = "ignore"
487+
444488
simple_info = _simplify_info(info)
445489
picks_list = _picks_by_type(info, meg_combined=True, ref_meg=False, exclude="bads")
446490
for ch_type, picks in picks_list:
@@ -503,6 +547,7 @@ def _compute_rank(
503547
False,
504548
tol_kind,
505549
log_ch_type=log_ch_type,
550+
on_few_samples=on_few_samples,
506551
)
507552
else:
508553
assert isinstance(inst, Covariance)
@@ -520,6 +565,7 @@ def _compute_rank(
520565
tol,
521566
return_singular=True,
522567
log_ch_type=log_ch_type,
568+
on_few_samples=on_few_samples,
523569
verbose=est_verbose,
524570
)
525571
if ch_type in rank:

mne/simulation/tests/test_evoked.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def test_add_noise():
137137
if inst is raw:
138138
cov_new = compute_raw_covariance(inst, picks=picks)
139139
else:
140-
cov_new = compute_covariance(inst)
140+
with pytest.warns(RuntimeWarning, match=".*Too few samples.*"):
141+
cov_new = compute_covariance(inst)
141142
assert cov["names"] == cov_new["names"]
142143
r = np.corrcoef(cov["data"].ravel(), cov_new["data"].ravel())[0, 1]
143144
assert r > 0.99

0 commit comments

Comments
 (0)