Skip to content

Commit 043c8be

Browse files
Genusterlarsoner
andauthored
BUG: Fix GED tolerances and improve cov validation (#13346)
Co-authored-by: Eric Larson <[email protected]>
1 parent bf57d9c commit 043c8be

File tree

3 files changed

+121
-60
lines changed

3 files changed

+121
-60
lines changed

mne/decoding/_ged.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,39 @@ def _smart_ged(S, R, restr_mat=None, R_func=None):
5454
return evals, evecs
5555

5656

57-
def _is_cov_symm_pos_semidef(
58-
cov, rtol=1e-10, atol=1e-11, eval_tol=1e-15, check_pos_semidef=True
59-
):
57+
def _is_cov_symm(cov, rtol=1e-7, atol=None):
58+
if atol is None:
59+
atol = 1e-7 * np.max(np.abs(cov))
6060
is_symm = scipy.linalg.issymmetric(cov, rtol=rtol, atol=atol)
61-
if not is_symm:
62-
return False
61+
return is_symm
6362

64-
if check_pos_semidef:
65-
# numerically slightly negative evals are considered 0
66-
is_pos_semidef = np.all(scipy.linalg.eigvalsh(cov) >= -eval_tol)
67-
return is_pos_semidef
6863

69-
return True
64+
def _get_cov_def(cov, eval_tol=None):
65+
"""Get definiteness of symmetric cov matrix.
7066
67+
All evals in (-eval_tol, eval_tol) will be considered zero,
68+
while all evals smaller than -eval_tol will be considered
69+
negative.
70+
"""
71+
evals = scipy.linalg.eigvalsh(cov)
72+
if eval_tol is None:
73+
eval_tol = 1e-7 * np.max(np.abs(evals))
74+
if np.all(evals > eval_tol):
75+
return "pos_def"
76+
elif np.all(evals >= -eval_tol):
77+
return "pos_semidef"
78+
else:
79+
return "indef"
80+
81+
82+
def _is_cov_pos_semidef(cov, eval_tol=None):
83+
cov_def = _get_cov_def(cov, eval_tol=eval_tol)
84+
return cov_def in ("pos_def", "pos_semidef")
7185

72-
def _is_cov_pos_def(cov, eval_tol=1e-15):
73-
is_symm = _is_cov_symm_pos_semidef(cov, check_pos_semidef=False)
74-
if not is_symm:
75-
return False
76-
# numerically slightly positive evals are considered 0
77-
is_pos_def = np.all(scipy.linalg.eigvalsh(cov) > eval_tol)
78-
return is_pos_def
86+
87+
def _is_cov_pos_def(cov, eval_tol=None):
88+
cov_def = _get_cov_def(cov, eval_tol=eval_tol)
89+
return cov_def == "pos_def"
7990

8091

8192
def _smart_ajd(covs, restr_mat=None, weights=None):
@@ -90,8 +101,8 @@ def _smart_ajd(covs, restr_mat=None, weights=None):
90101
from .csp import _ajd_pham
91102

92103
if restr_mat is None:
93-
is_all_pos_def = all([_is_cov_pos_def(cov) for cov in covs])
94-
if not is_all_pos_def:
104+
are_all_pos_def = all([_is_cov_pos_def(cov) for cov in covs])
105+
if not are_all_pos_def:
95106
raise ValueError(
96107
"If C_ref is not provided by covariance estimator, "
97108
"all the covs should be positive definite"
@@ -100,6 +111,12 @@ def _smart_ajd(covs, restr_mat=None, weights=None):
100111
return evecs
101112

102113
else:
114+
are_all_pos_semidef = all([_is_cov_pos_semidef(cov) for cov in covs])
115+
if not are_all_pos_semidef:
116+
raise ValueError(
117+
"All the covs should be positive semi-definite for "
118+
"approximate joint diagonalization"
119+
)
103120
covs = np.array([restr_mat @ cov @ restr_mat.T for cov in covs], float)
104121
evecs_restr, D = _ajd_pham(covs)
105122
evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights)

mne/decoding/base.py

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@
2626

2727
from ..parallel import parallel_func
2828
from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn
29-
from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged
29+
from ._ged import (
30+
_handle_restr_mat,
31+
_is_cov_pos_semidef,
32+
_is_cov_symm,
33+
_smart_ajd,
34+
_smart_ged,
35+
)
3036
from ._mod_ged import _no_op_mod
3137
from .transformer import MNETransformerMixin
3238

@@ -133,23 +139,21 @@ def fit(self, X, y=None):
133139
covs, C_ref, info, rank, kwargs = self.cov_callable(X, y)
134140
covs = np.stack(covs)
135141
self._validate_covariances(covs)
136-
self._validate_covariances([C_ref])
142+
if C_ref is not None:
143+
self._validate_covariances([C_ref])
137144
mod_ged_callable = (
138145
self.mod_ged_callable if self.mod_ged_callable is not None else _no_op_mod
139146
)
147+
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
140148

141149
if self.dec_type == "single":
142150
if len(covs) > 2:
143-
weights = (
144-
kwargs["sample_weights"] if "sample_weights" in kwargs else None
145-
)
146-
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
151+
weights = kwargs.get("sample_weights", None)
147152
evecs = _smart_ajd(covs, restr_mat, weights=weights)
148153
evals = None
149154
else:
150155
S = covs[0]
151156
R = covs[1]
152-
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
153157
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
154158

155159
evals, evecs, self.sorter_ = mod_ged_callable(evals, evecs, covs, **kwargs)
@@ -160,7 +164,6 @@ def fit(self, X, y=None):
160164
elif self.dec_type == "multi":
161165
self.classes_ = np.unique(y)
162166
R = covs[-1]
163-
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
164167
all_evals, all_evecs = list(), list()
165168
all_patterns, all_sorters = list(), list()
166169
for i in range(len(self.classes_)):
@@ -251,18 +254,66 @@ def _validate_ged_params(self):
251254
)
252255

253256
def _validate_covariances(self, covs):
254-
for cov in covs:
255-
if cov is None:
256-
continue
257-
# XXX: A lot of mne.decoding classes use mne.cov._regularized_covariance.
258-
# Depending on the data it sometimes returns negative semidefinite matrices.
259-
# So adding the validation of positive semidefinitiveness
260-
# will require overhauling covariance estimation first.
261-
is_cov = _is_cov_symm_pos_semidef(cov, check_pos_semidef=False)
262-
if not is_cov:
257+
error_template = (
258+
"{matrix} is not {prop}, but required to be for {decomp}. "
259+
"Check your cov_callable"
260+
)
261+
if len(covs) == 1:
262+
C_ref = covs[0]
263+
is_C_ref_symm = _is_cov_symm(C_ref)
264+
if not is_C_ref_symm:
265+
raise ValueError(
266+
error_template.format(
267+
matrix="C_ref covariance",
268+
prop="symmetric",
269+
decomp="decomposition",
270+
)
271+
)
272+
elif self.dec_type == "single" and len(covs) > 2:
273+
# make only lenient symmetric check here.
274+
# positive semidefiniteness/definiteness will be
275+
# checked inside _smart_ajd
276+
for ci, cov in enumerate(covs):
277+
if not _is_cov_symm(cov):
278+
raise ValueError(
279+
error_template.format(
280+
matrix=f"cov[{ci}]",
281+
prop="symmetric",
282+
decomp="approximate joint diagonalization",
283+
)
284+
)
285+
else:
286+
if len(covs) == 2:
287+
S_covs = [covs[0]]
288+
R = covs[1]
289+
elif self.dec_type == "multi":
290+
S_covs = covs[:-1]
291+
R = covs[-1]
292+
293+
are_all_S_symm = all([_is_cov_symm(S) for S in S_covs])
294+
if not are_all_S_symm:
295+
raise ValueError(
296+
error_template.format(
297+
matrix="S covariance",
298+
prop="symmetric",
299+
decomp="generalized eigendecomposition",
300+
)
301+
)
302+
if not _is_cov_symm(R):
303+
raise ValueError(
304+
error_template.format(
305+
matrix="R covariance",
306+
prop="symmetric",
307+
decomp="generalized eigendecomposition",
308+
)
309+
)
310+
if not _is_cov_pos_semidef(R):
263311
raise ValueError(
264-
"One of covariances is not symmetric (or positive semidefinite), "
265-
"check your cov_callable"
312+
error_template.format(
313+
matrix="R covariance",
314+
prop="positive semi-definite",
315+
decomp="generalized eigendecomposition",
316+
)
266317
)
267318

268319
def __sklearn_tags__(self):

mne/decoding/tests/test_ged.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from mne._fiff.proj import make_eeg_average_ref_proj
2020
from mne.cov import Covariance, _regularized_covariance
2121
from mne.decoding._ged import (
22+
_get_cov_def,
2223
_get_restr_mat,
2324
_handle_restr_mat,
24-
_is_cov_pos_def,
25-
_is_cov_symm_pos_semidef,
25+
_is_cov_symm,
2626
_smart_ajd,
2727
_smart_ged,
2828
)
@@ -345,34 +345,27 @@ def test__handle_restr_mat_invalid_restr_type():
345345

346346
def test_cov_validators():
347347
"""Test that covariance validators indeed validate."""
348-
asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
348+
asymm_indef = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
349349
sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
350350
pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]])
351351

352-
assert not _is_cov_symm_pos_semidef(asymm)
353-
assert _is_cov_symm_pos_semidef(sing_pos_semidef)
354-
assert _is_cov_symm_pos_semidef(pos_def)
352+
assert not _is_cov_symm(asymm_indef)
353+
assert _get_cov_def(asymm_indef) == "indef"
354+
assert _get_cov_def(sing_pos_semidef) == "pos_semidef"
355+
assert _get_cov_def(pos_def) == "pos_def"
355356

356-
assert not _is_cov_pos_def(asymm)
357-
assert not _is_cov_pos_def(sing_pos_semidef)
358-
assert _is_cov_pos_def(pos_def)
359357

360-
361-
def test__is_cov_pos_def():
362-
"""Test _is_cov_pos_def works."""
363-
asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
364-
sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
365-
pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]])
366-
assert not _is_cov_pos_def(asymm)
367-
assert not _is_cov_pos_def(sing_pos_semidef)
368-
assert _is_cov_pos_def(pos_def)
369-
370-
371-
def test__smart_ajd_when_restr_mat_is_none():
372-
"""Test _smart_ajd raises ValueError when restr_mat is None."""
358+
def test__smart_ajd_raises():
359+
"""Test _smart_ajd raises proper ValueErrors."""
360+
asymm_indef = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
373361
sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
374362
pos_def1 = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]])
375363
pos_def2 = np.array([[10, 1, 2], [1, 12, 3], [2, 3, 15]])
364+
365+
bad_covs = np.stack([sing_pos_semidef, asymm_indef, pos_def1])
366+
with pytest.raises(ValueError, match="positive semi-definite"):
367+
_smart_ajd(bad_covs, restr_mat=pos_def2, weights=None)
368+
376369
bad_covs = np.stack([sing_pos_semidef, pos_def1, pos_def2])
377370
with pytest.raises(ValueError, match="positive definite"):
378371
_smart_ajd(bad_covs, restr_mat=None, weights=None)

0 commit comments

Comments
 (0)