Skip to content

Commit 41dbdd5

Browse files
fix scaling in Spectrum.plot(amplitude=True, dB=True) (#13036)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent dcd2625 commit 41dbdd5

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

doc/changes/devel/13036.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix plot scaling for :meth:`Spectrum.plot(dB=True, amplitude=True) <mne.time_frequency.Spectrum.plot>`, by `Daniel McCloy`_.

mne/time_frequency/tests/test_spectrum.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,3 +627,29 @@ def test_plot_spectrum_array_with_bads():
627627
spectrum.get_data(exclude=()), spectrum.info, spectrum.freqs
628628
)
629629
spectrum2.plot(spatial_colors=False)
630+
631+
632+
@pytest.mark.parametrize("dB", (False, True))
633+
@pytest.mark.parametrize("amplitude", (False, True))
634+
def test_plot_spectrum_dB(raw_spectrum, dB, amplitude):
635+
"""Test that we properly handle amplitude/power and dB."""
636+
idx = 7
637+
power = 3
638+
freqs = np.linspace(1, 100, 100)
639+
data = np.full((1, freqs.size), np.finfo(float).tiny)
640+
data[0, idx] = power
641+
info = create_info(ch_names=["delta"], sfreq=1000, ch_types="eeg")
642+
psd = SpectrumArray(data=data, info=info, freqs=freqs)
643+
with pytest.warns(RuntimeWarning, match="Channel locations not available"):
644+
fig = psd.plot(dB=dB, amplitude=amplitude)
645+
trace = list(
646+
filter(lambda x: len(x.get_data()[0]) == len(freqs), fig.axes[0].lines)
647+
)[0]
648+
got = trace.get_data()[1][idx]
649+
want = power * 1e12 # scaling for EEG (V → μV), squared
650+
if amplitude:
651+
want = np.sqrt(want)
652+
if dB:
653+
want = (20 if amplitude else 10) * np.log10(want)
654+
655+
assert want == got, f"expected {want}, got {got}"

mne/viz/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2390,14 +2390,16 @@ def _convert_psds(
23902390
np.sqrt(psds, out=psds)
23912391
psds *= scaling
23922392
ylabel = rf"$\mathrm{{{unit}/\sqrt{{Hz}}}}$"
2393+
coef = 20
23932394
else:
23942395
psds *= scaling * scaling
23952396
if "/" in unit:
23962397
unit = f"({unit})"
23972398
ylabel = rf"$\mathrm{{{unit}²/Hz}}$"
2399+
coef = 10
23982400
if dB:
23992401
np.log10(np.maximum(psds, np.finfo(float).tiny), out=psds)
2400-
psds *= 10
2402+
psds *= coef
24012403
ylabel = r"$\mathrm{dB}\ $" + ylabel
24022404
ylabel = "Power (" + ylabel if estimate == "power" else "Amplitude (" + ylabel
24032405
ylabel += ")"

0 commit comments

Comments
 (0)