
Commit b9780bf

BUG: Fix bug with plot_white
1 parent 0a88d41 · commit b9780bf

6 files changed: +66, -25 lines


doc/changes/dev/13595.bugfix.rst

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Fix bug where :func:`mne.viz.plot_evoked_white` did not accept a single "meg" rank value like those returned from :func:`mne.compute_rank`, by `Eric Larson`_.
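
In practice, the fix means a rank dict with a single "meg" entry, like the one returned by mne.compute_rank, can be passed straight to Evoked.plot_white. A minimal usage sketch (not part of the commit; the file names are placeholders):

    import mne

    # hypothetical paths -- substitute your own evoked response and noise covariance
    evoked = mne.read_evokeds("sample-ave.fif", condition=0)
    cov = mne.read_cov("sample-cov.fif")

    # rank="info" can yield a single combined entry such as {"meg": 64}
    # rather than separate "mag"/"grad" values
    rank = mne.compute_rank(evoked, rank="info")

    # on non-SSS data this previously raised "When not using SSS, pass
    # separate rank values ..."; mag and grad are now whitened jointly instead
    evoked.plot_white(cov, rank=rank)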

mne/minimum_norm/tests/test_inverse.py

Lines changed: 20 additions & 6 deletions

@@ -23,6 +23,7 @@
     EvokedArray,
     SourceEstimate,
     combine_evoked,
+    compute_rank,
     compute_raw_covariance,
     convert_forward_solution,
     make_ad_hoc_cov,
@@ -992,21 +993,34 @@ def test_make_inverse_operator_diag(evoked, noise_cov, tmp_path, azure_windows):

 def test_inverse_operator_noise_cov_rank(evoked, noise_cov):
     """Test MNE inverse operator with a specified noise cov rank."""
-    fwd_op = read_forward_solution_meg(fname_fwd, surf_ori=True)
-    inv = make_inverse_operator(evoked.info, fwd_op, noise_cov, rank=dict(meg=64))
+    fwd_op_meg = read_forward_solution_meg(fname_fwd, surf_ori=True)
+    inv = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank=dict(meg=64))
     assert compute_rank_inverse(inv) == 64
-    inv = make_inverse_operator(evoked.info, fwd_op, noise_cov, rank=dict(meg=64))
+    inv = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank=dict(meg=64))
     assert compute_rank_inverse(inv) == 64

     bad_cov = noise_cov.copy()
     bad_cov["data"][0, 0] *= 1e12
     with pytest.warns(RuntimeWarning, match="orders of magnitude"):
-        make_inverse_operator(evoked.info, fwd_op, bad_cov, rank=dict(meg=64))
+        make_inverse_operator(evoked.info, fwd_op_meg, bad_cov, rank=dict(meg=64))

-    fwd_op = read_forward_solution_eeg(fname_fwd, surf_ori=True)
-    inv = make_inverse_operator(evoked.info, fwd_op, noise_cov, rank=dict(eeg=20))
+    fwd_op_eeg = read_forward_solution_eeg(fname_fwd, surf_ori=True)
+    inv = make_inverse_operator(evoked.info, fwd_op_eeg, noise_cov, rank=dict(eeg=20))
     assert compute_rank_inverse(inv) == 20

+    # with and without rank passed explicitly
+    inv_info = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank="info")
+    info_rank = 302
+    assert compute_rank_inverse(inv_info) == info_rank
+    rank = compute_rank(noise_cov, info=evoked.copy().pick("meg").info, rank="info")
+    assert "meg" in rank
+    assert sum(rank.values()) == info_rank
+    inv_rank = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank=rank)
+    assert compute_rank_inverse(inv_rank) == info_rank
+    evoked_info = apply_inverse(evoked, inv_info, lambda2, "MNE")
+    evoked_rank = apply_inverse(evoked, inv_rank, lambda2, "MNE")
+    assert_allclose(evoked_rank.data, evoked_info.data)
+

 def test_inverse_operator_volume(evoked, tmp_path):
     """Test MNE inverse computation on volume source space."""

mne/tests/test_cov.py

Lines changed: 8 additions & 0 deletions

@@ -94,6 +94,14 @@ def test_compute_whitener(proj, pca):
         assert pca is False
         assert_allclose(round_trip, np.eye(n_channels), atol=0.05)

+    # with and without rank
+    W_info, _ = compute_whitener(cov, raw.info, pca=pca, rank="info", verbose="error")
+    assert_allclose(W_info, W)
+    rank = compute_rank(raw, rank="info", proj=proj)
+    assert W.shape == (n_reduced, n_channels)
+    W_rank, _ = compute_whitener(cov, raw.info, pca=pca, rank=rank, verbose="error")
+    assert_allclose(W_rank, W)
+
     raw.info["bads"] = [raw.ch_names[0]]
     picks = pick_types(raw.info, meg=True, eeg=True, exclude=[])
     with pytest.warns(RuntimeWarning, match="Too few samples"):
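
The new assertions check the same equivalence for the whitening matrix itself. A rough sketch with assumed inputs (the raw file name is a placeholder):

    import numpy as np
    import mne

    # hypothetical recording and a covariance estimated from it
    raw = mne.io.read_raw_fif("sample-raw.fif", preload=True)
    cov = mne.compute_raw_covariance(raw)

    # rank taken from the measurement info, implicitly or via an explicit dict
    W_info, _ = mne.compute_whitener(cov, raw.info, rank="info")
    rank = mne.compute_rank(raw, rank="info")
    W_rank, _ = mne.compute_whitener(cov, raw.info, rank=rank)

    # the two whitening matrices should agree, which is what the test asserts
    np.testing.assert_allclose(W_rank, W_info)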

mne/viz/evoked.py

Lines changed: 6 additions & 10 deletions

@@ -1587,13 +1587,9 @@ def plot_evoked_white(
         evoked.del_proj(idx)

     evoked.pick_types(ref_meg=False, exclude="bads", **_PICK_TYPES_DATA_DICT)
-    n_ch_used, rank_list, picks_list, has_sss = _triage_rank_sss(
+    n_ch_used, rank_list, picks_list, meg_combined = _triage_rank_sss(
         evoked.info, noise_cov, rank, scalings=None
     )
-    if has_sss:
-        logger.info(
-            "SSS has been applied to data. Showing mag and grad whitening jointly."
-        )

     # get one whitened evoked per cov
     evokeds_white = [
@@ -1663,8 +1659,8 @@ def whitened_gfp(x, rank=None):
     # hacks to get it to plot all channels in the same axes, namely setting
     # the channel unit (most important) and coil type (for consistency) of
     # all MEG channels to be the same.
-    meg_idx = sss_title = None
-    if has_sss:
+    meg_idx = combined_title = None
+    if meg_combined:
         titles_["meg"] = "MEG (combined)"
         meg_idx = [
             pi for pi, (ch_type, _) in enumerate(picks_list) if ch_type == "meg"
@@ -1675,7 +1671,7 @@
             use = evokeds_white[0].info["chs"][picks[0]][key]
             for pick in picks:
                 evokeds_white[0].info["chs"][pick][key] = use
-        sss_title = f"{titles_['meg']} ({len(picks)} channel{_pl(picks)})"
+        combined_title = f"{titles_['meg']} ({len(picks)} channel{_pl(picks)})"
     evokeds_white[0].plot(
         unit=False,
         axes=axes_evoked,
@@ -1684,8 +1680,8 @@
         time_unit=time_unit,
         spatial_colors=spatial_colors,
     )
-    if has_sss:
-        axes_evoked[meg_idx].set(title=sss_title)
+    if meg_combined:
+        axes_evoked[meg_idx].set(title=combined_title)

     # Now plot the GFP for all covs if indicated.
     for evoked_white, noise_cov, rank_, color in iter_gfp:

mne/viz/tests/test_evoked.py

Lines changed: 18 additions & 2 deletions

@@ -18,6 +18,7 @@
     Epochs,
     compute_covariance,
     compute_proj_evoked,
+    compute_rank,
     make_fixed_length_events,
     read_cov,
     read_events,
@@ -357,6 +358,21 @@ def test_plot_evoked_image():
     evoked.plot_image(clim=[-4, 4])


+def test_plot_white_rank():
+    """Test plot_white with a combined-MEG rank arg."""
+    cov = read_cov(cov_fname)
+    cov["method"] = "empirical"
+    cov["projs"] = []  # avoid warnings
+    evoked = _get_epochs().average()
+    evoked.set_eeg_reference("average")  # Avoid warnings
+    rank = compute_rank(evoked, "info")
+    assert "grad" not in rank
+    assert "mag" not in rank
+    assert "meg" in rank
+    evoked.plot_white(cov)
+    evoked.plot_white(cov, rank=rank)
+
+
 def test_plot_white():
     """Test plot_white."""
     cov = read_cov(cov_fname)
@@ -373,9 +389,9 @@ def test_plot_white():
     evoked.plot_white(cov, rank={"grad": 8}, time_unit="s", axes=fig.axes[:4])
     with pytest.raises(ValueError, match=r"must have shape \(4,\), got \(2,"):
         evoked.plot_white(cov, axes=fig.axes[:2])
-    with pytest.raises(ValueError, match="When not using SSS"):
+    with pytest.raises(ValueError, match="exceeds the number"):
         evoked.plot_white(cov, rank={"meg": 306})
-    evoked.plot_white([cov, cov], time_unit="s")
+    evoked.plot_white([cov, cov], rank={"meg": 9}, time_unit="s")
     plt.close("all")

     fig = plot_evoked_white(evoked, [cov, cov])

mne/viz/utils.py

Lines changed: 13 additions & 7 deletions

@@ -2070,14 +2070,20 @@ def _triage_rank_sss(info, covs, rank=None, scalings=None):
                 '(separate rank values for "mag" or "grad" are '
                 "meaningless)."
             )
+        meg_combined = True
     elif "meg" in rank:
-        raise ValueError(
-            "When not using SSS, pass separate rank values "
-            'for "mag" and "grad" (do not use "meg").'
-        )
+        if has_sss:
+            start = "SSS has been applied to data"
+        else:
+            start = "Got a single MEG rank value"
+        logger.info("%s. Showing mag and grad whitening jointly.", start)
+        meg_combined = True
+    else:
+        meg_combined = False
+    del has_sss

-    picks_list = _picks_by_type(info, meg_combined=has_sss)
-    if has_sss:
+    picks_list = _picks_by_type(info, meg_combined=meg_combined)
+    if meg_combined:
         # reduce ch_used to combined mag grad
         ch_used = list(zip(*picks_list))[0]
         # order pick list by ch_used (required for compat with plot_evoked)
@@ -2121,7 +2127,7 @@
             this_rank[ch_type] = rank[ch_type]

         rank_list.append(this_rank)
-    return n_ch_used, rank_list, picks_list, has_sss
+    return n_ch_used, rank_list, picks_list, meg_combined


 def _check_cov(noise_cov, info):
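
Taken together, the new branching reduces to a simple rule for when MEG channels are whitened jointly. A simplified restatement of the logic visible in this hunk (a sketch only; the real _triage_rank_sss also validates the rank dict and builds picks_list and rank_list):

    def meg_whitening_is_combined(rank: dict, has_sss: bool) -> bool:
        """Sketch of the meg_combined decision shown in the diff above."""
        if has_sss:
            # SSS/Maxwell filtering mixes mag and grad, so they are always
            # shown and whitened jointly
            return True
        if "meg" in rank:
            # a single combined MEG rank (e.g. from mne.compute_rank) no longer
            # raises -- it now also triggers joint mag/grad whitening
            return True
        return False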
