diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index c12c1f0945e..fe6ecc6b0c7 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1821,11 +1821,18 @@ def plot_evoked_joint( ---------- evoked : instance of Evoked The evoked instance. - times : float | array of float | "auto" | "peaks" + times : float | array of float | "auto" | "peaks" | dict The time point(s) to plot. If ``"auto"``, 5 evenly spaced topographies between the first and last time instant will be shown. If ``"peaks"``, finds time points automatically by checking for 3 local maxima in Global Field Power. Defaults to ``"peaks"``. + If a dict, the key must be ``"peaks"`` and the value + can be : + * an int (number of peaks to find over the whole epoch) + * a list of tuples (time windows to find one peak in each) + + Defaults to ``"peaks"``. If you want to use evenly spaced time points in + an interval, use :func:`numpy.linspace`. title : str | None The title. If ``None``, suppress printing channel type title. If an empty string, a default title is created. Defaults to ''. If custom @@ -1890,6 +1897,20 @@ def plot_evoked_joint( if times in (None, "peaks"): n_topomaps = 3 + 1 + elif isinstance(times, dict) and "peaks" in times: + if len(times) != 1: + raise ValueError( + "If 'times' is a dict, it must have only one key 'peaks'." + ) + val = times["peaks"] + if isinstance(val, int): + n_topomaps = val + 1 + elif isinstance(val, (list, tuple)): + n_topomaps = len(val) + 1 + else: + raise ValueError( + "Values for 'peaks' must be int or list/tuple of tuples." + ) else: assert not isinstance(times, str) n_topomaps = len(times) + 1 diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index 964acae2b31..0484d791a19 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -665,3 +665,34 @@ def get_axes_midpoints(axes): ) midpoints_after = get_axes_midpoints(topo_axes) assert (np.linalg.norm(midpoints_before - midpoints_after) < 0.1).all() + + +def test_plot_joint_times_dict(): + """Test using a dictionary for the 'times' parameter in plot_joint.""" + ch_names = ["F3", "Fz", "F4"] + sfreq = 1000.0 + info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg") + + data = np.zeros((3, 500)) + + data[:, 100] = 5e-6 + data[:, 300] = 5e-6 + + evoked = mne.EvokedArray(data, info, tmin=0) + + evoked.set_montage(mne.channels.make_standard_montage("standard_1020")) + + # Test 1: Integer count (Logic check: does it find N peaks?) + fig = evoked.plot_joint(times={"peaks": 3}, show=False) + assert len(fig.axes) >= 2 + + # Test 2: Specific Windows (Logic check: does it parse windows?) + fig2 = evoked.plot_joint(times={"peaks": [(0.0, 0.2), (0.2, 0.4)]}, show=False) + assert len(fig2.axes) >= 2 + + # Test 3: Validation checks (Ensuring robust error handling) + with pytest.raises(ValueError, match="must be 'peaks'"): + evoked.plot_joint(times={"bad_key": 5}) + + with pytest.raises(ValueError, match="Values for 'peaks' must be"): + evoked.plot_joint(times={"peaks": "invalid_string"}) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 9c71714040a..04e344486b3 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -895,6 +895,29 @@ def _find_peaks(evoked, npeaks): def _process_times(inst, use_times, n_peaks=None, few=False): """Return a list of times for topomaps.""" + if isinstance(use_times, dict): + if "peaks" not in use_times: + raise ValueError("If 'times' is a dict, the key must be 'peaks'.") + + peak_params = use_times["peaks"] + + if isinstance(peak_params, int): + use_times = _find_peaks(inst, peak_params) + elif isinstance(peak_params, (list, tuple)): + peaks = [] + for window in peak_params: + if len(window) != 2: + raise ValueError(f"Each window must be (tmin, tmax), got {window}") + tmin, tmax = window + _, t_peak = inst.get_peak(tmin=tmin, tmax=tmax, mode="abs") + peaks.append(t_peak) + use_times = np.array(peaks) + else: + raise ValueError( + "Values for 'peaks' must be an integer or a list/tuple of " + " (tmin, tmax) windows." + ) + if isinstance(use_times, str): if use_times == "interactive": use_times, n_peaks = "peaks", 1