From fa835bf67f94af8516cf693105029f5b9d2127f5 Mon Sep 17 00:00:00 2001 From: Shruti Bhale Date: Sat, 3 Jan 2026 22:04:52 +0530 Subject: [PATCH 1/2] ENH: Allow dict for times in plot_evoked_joint --- mne/viz/evoked.py | 16 +++++++++++++++- mne/viz/tests/test_evoked.py | 32 ++++++++++++++++++++++++++++++++ mne/viz/utils.py | 21 +++++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index c12c1f0945e..2bf59805d81 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1821,11 +1821,15 @@ def plot_evoked_joint( ---------- evoked : instance of Evoked The evoked instance. - times : float | array of float | "auto" | "peaks" + times : float | array of float | "auto" | "peaks" | dict The time point(s) to plot. If ``"auto"``, 5 evenly spaced topographies between the first and last time instant will be shown. If ``"peaks"``, finds time points automatically by checking for 3 local maxima in Global Field Power. Defaults to ``"peaks"``. + If a dict, the key must be ``"peaks"`` with a value of either an integer + (number of peaks to find) or a list/tuple of tuples (time windows to + find a peak in). If you want to use evenly spaced time points in an + interval, use :func:`numpy.linspace`. title : str | None The title. If ``None``, suppress printing channel type title. If an empty string, a default title is created. Defaults to ''. If custom @@ -1890,6 +1894,16 @@ def plot_evoked_joint( if times in (None, "peaks"): n_topomaps = 3 + 1 + elif isinstance(times, dict) and "peaks" in times: + if len(times) != 1: + raise ValueError("If 'times' is a dict, it must have only one key 'peaks'.") + val = times["peaks"] + if isinstance(val, int): + n_topomaps = val + 1 + elif isinstance(val, (list, tuple)): + n_topomaps = len(val) + 1 + else: + raise ValueError("Values for 'peaks' must be int or list/tuple of tuples.") else: assert not isinstance(times, str) n_topomaps = len(times) + 1 diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index 964acae2b31..b9102637b23 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -665,3 +665,35 @@ def get_axes_midpoints(axes): ) midpoints_after = get_axes_midpoints(topo_axes) assert (np.linalg.norm(midpoints_before - midpoints_after) < 0.1).all() + + +def test_plot_joint_times_dict(): + """Test using a dictionary for the 'times' parameter in plot_joint.""" + + ch_names = ['F3', 'Fz', 'F4'] + sfreq = 1000. + info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') + + data = np.zeros((3, 500)) + + data[:, 100] = 5e-6 + data[:, 300] = 5e-6 + + evoked = mne.EvokedArray(data, info, tmin=0) + + evoked.set_montage(mne.channels.make_standard_montage('standard_1020')) + + # Test 1: Integer count (Logic check: does it find N peaks?) + fig = evoked.plot_joint(times={"peaks": 3}, show=False) + assert len(fig.axes) >= 2 + + # Test 2: Specific Windows (Logic check: does it parse windows?) + fig2 = evoked.plot_joint(times={"peaks": [(0.0, 0.2), (0.2, 0.4)]}, show=False) + assert len(fig2.axes) >= 2 + + # Test 3: Validation checks (Ensuring robust error handling) + with pytest.raises(ValueError, match="must be 'peaks'"): + evoked.plot_joint(times={"bad_key": 5}) + + with pytest.raises(ValueError, match="Values for 'peaks' must be"): + evoked.plot_joint(times={"peaks": "invalid_string"}) \ No newline at end of file diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 9c71714040a..57b180d4dc9 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -895,6 +895,27 @@ def _find_peaks(evoked, npeaks): def _process_times(inst, use_times, n_peaks=None, few=False): """Return a list of times for topomaps.""" + + if isinstance(use_times, dict): + if "peaks" not in use_times: + raise ValueError("If 'times' is a dict, the key must be 'peaks'.") + + peak_params = use_times["peaks"] + + if isinstance(peak_params, int): + use_times = _find_peaks(inst, peak_params) + elif isinstance(peak_params, (list, tuple)): + peaks = [] + for window in peak_params: + if len(window) != 2: + raise ValueError(f"Each window must be (tmin, tmax), got {window}") + tmin, tmax = window + _, t_peak = inst.get_peak(tmin=tmin, tmax=tmax, mode='abs') + peaks.append(t_peak) + use_times = np.array(peaks) + else: + raise ValueError("Values for 'peaks' must be an integer or a list/tuple of (tmin, tmax) windows.") + if isinstance(use_times, str): if use_times == "interactive": use_times, n_peaks = "peaks", 1 From 2c0734b982b4b7b26b8cf30481a7f3ee9f6380e4 Mon Sep 17 00:00:00 2001 From: Shruti Bhale Date: Mon, 5 Jan 2026 21:01:13 +0530 Subject: [PATCH 2/2] MAINT: Fix style and documentation formatting --- mne/viz/evoked.py | 19 +++++++++++++------ mne/viz/tests/test_evoked.py | 29 ++++++++++++++--------------- mne/viz/utils.py | 14 ++++++++------ 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 2bf59805d81..fe6ecc6b0c7 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1826,10 +1826,13 @@ def plot_evoked_joint( between the first and last time instant will be shown. If ``"peaks"``, finds time points automatically by checking for 3 local maxima in Global Field Power. Defaults to ``"peaks"``. - If a dict, the key must be ``"peaks"`` with a value of either an integer - (number of peaks to find) or a list/tuple of tuples (time windows to - find a peak in). If you want to use evenly spaced time points in an - interval, use :func:`numpy.linspace`. + If a dict, the key must be ``"peaks"`` and the value + can be : + * an int (number of peaks to find over the whole epoch) + * a list of tuples (time windows to find one peak in each) + + Defaults to ``"peaks"``. If you want to use evenly spaced time points in + an interval, use :func:`numpy.linspace`. title : str | None The title. If ``None``, suppress printing channel type title. If an empty string, a default title is created. Defaults to ''. If custom @@ -1896,14 +1899,18 @@ def plot_evoked_joint( n_topomaps = 3 + 1 elif isinstance(times, dict) and "peaks" in times: if len(times) != 1: - raise ValueError("If 'times' is a dict, it must have only one key 'peaks'.") + raise ValueError( + "If 'times' is a dict, it must have only one key 'peaks'." + ) val = times["peaks"] if isinstance(val, int): n_topomaps = val + 1 elif isinstance(val, (list, tuple)): n_topomaps = len(val) + 1 else: - raise ValueError("Values for 'peaks' must be int or list/tuple of tuples.") + raise ValueError( + "Values for 'peaks' must be int or list/tuple of tuples." + ) else: assert not isinstance(times, str) n_topomaps = len(times) + 1 diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index b9102637b23..0484d791a19 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -669,31 +669,30 @@ def get_axes_midpoints(axes): def test_plot_joint_times_dict(): """Test using a dictionary for the 'times' parameter in plot_joint.""" - - ch_names = ['F3', 'Fz', 'F4'] - sfreq = 1000. - info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') - + ch_names = ["F3", "Fz", "F4"] + sfreq = 1000.0 + info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg") + data = np.zeros((3, 500)) - - data[:, 100] = 5e-6 - data[:, 300] = 5e-6 - + + data[:, 100] = 5e-6 + data[:, 300] = 5e-6 + evoked = mne.EvokedArray(data, info, tmin=0) - - evoked.set_montage(mne.channels.make_standard_montage('standard_1020')) + + evoked.set_montage(mne.channels.make_standard_montage("standard_1020")) # Test 1: Integer count (Logic check: does it find N peaks?) fig = evoked.plot_joint(times={"peaks": 3}, show=False) - assert len(fig.axes) >= 2 - + assert len(fig.axes) >= 2 + # Test 2: Specific Windows (Logic check: does it parse windows?) fig2 = evoked.plot_joint(times={"peaks": [(0.0, 0.2), (0.2, 0.4)]}, show=False) assert len(fig2.axes) >= 2 - + # Test 3: Validation checks (Ensuring robust error handling) with pytest.raises(ValueError, match="must be 'peaks'"): evoked.plot_joint(times={"bad_key": 5}) with pytest.raises(ValueError, match="Values for 'peaks' must be"): - evoked.plot_joint(times={"peaks": "invalid_string"}) \ No newline at end of file + evoked.plot_joint(times={"peaks": "invalid_string"}) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 57b180d4dc9..04e344486b3 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -895,13 +895,12 @@ def _find_peaks(evoked, npeaks): def _process_times(inst, use_times, n_peaks=None, few=False): """Return a list of times for topomaps.""" - if isinstance(use_times, dict): if "peaks" not in use_times: raise ValueError("If 'times' is a dict, the key must be 'peaks'.") - + peak_params = use_times["peaks"] - + if isinstance(peak_params, int): use_times = _find_peaks(inst, peak_params) elif isinstance(peak_params, (list, tuple)): @@ -910,12 +909,15 @@ def _process_times(inst, use_times, n_peaks=None, few=False): if len(window) != 2: raise ValueError(f"Each window must be (tmin, tmax), got {window}") tmin, tmax = window - _, t_peak = inst.get_peak(tmin=tmin, tmax=tmax, mode='abs') + _, t_peak = inst.get_peak(tmin=tmin, tmax=tmax, mode="abs") peaks.append(t_peak) use_times = np.array(peaks) else: - raise ValueError("Values for 'peaks' must be an integer or a list/tuple of (tmin, tmax) windows.") - + raise ValueError( + "Values for 'peaks' must be an integer or a list/tuple of " + " (tmin, tmax) windows." + ) + if isinstance(use_times, str): if use_times == "interactive": use_times, n_peaks = "peaks", 1