Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,11 +1821,18 @@ def plot_evoked_joint(
----------
evoked : instance of Evoked
The evoked instance.
times : float | array of float | "auto" | "peaks"
times : float | array of float | "auto" | "peaks" | dict
The time point(s) to plot. If ``"auto"``, 5 evenly spaced topographies
between the first and last time instant will be shown. If ``"peaks"``,
finds time points automatically by checking for 3 local maxima in
Global Field Power. Defaults to ``"peaks"``.
If a dict, the key must be ``"peaks"`` and the value
can be :
* an int (number of peaks to find over the whole epoch)
* a list of tuples (time windows to find one peak in each)

Defaults to ``"peaks"``. If you want to use evenly spaced time points in
an interval, use :func:`numpy.linspace`.
title : str | None
The title. If ``None``, suppress printing channel type title. If an
empty string, a default title is created. Defaults to ''. If custom
Expand Down Expand Up @@ -1890,6 +1897,20 @@ def plot_evoked_joint(

if times in (None, "peaks"):
n_topomaps = 3 + 1
elif isinstance(times, dict) and "peaks" in times:
if len(times) != 1:
raise ValueError(
"If 'times' is a dict, it must have only one key 'peaks'."
)
val = times["peaks"]
if isinstance(val, int):
n_topomaps = val + 1
elif isinstance(val, (list, tuple)):
n_topomaps = len(val) + 1
else:
raise ValueError(
"Values for 'peaks' must be int or list/tuple of tuples."
)
else:
assert not isinstance(times, str)
n_topomaps = len(times) + 1
Expand Down
31 changes: 31 additions & 0 deletions mne/viz/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,34 @@ def get_axes_midpoints(axes):
)
midpoints_after = get_axes_midpoints(topo_axes)
assert (np.linalg.norm(midpoints_before - midpoints_after) < 0.1).all()


def test_plot_joint_times_dict():
"""Test using a dictionary for the 'times' parameter in plot_joint."""
ch_names = ["F3", "Fz", "F4"]
sfreq = 1000.0
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")

data = np.zeros((3, 500))

data[:, 100] = 5e-6
data[:, 300] = 5e-6

evoked = mne.EvokedArray(data, info, tmin=0)

evoked.set_montage(mne.channels.make_standard_montage("standard_1020"))

# Test 1: Integer count (Logic check: does it find N peaks?)
fig = evoked.plot_joint(times={"peaks": 3}, show=False)
assert len(fig.axes) >= 2

# Test 2: Specific Windows (Logic check: does it parse windows?)
fig2 = evoked.plot_joint(times={"peaks": [(0.0, 0.2), (0.2, 0.4)]}, show=False)
assert len(fig2.axes) >= 2

# Test 3: Validation checks (Ensuring robust error handling)
with pytest.raises(ValueError, match="must be 'peaks'"):
evoked.plot_joint(times={"bad_key": 5})

with pytest.raises(ValueError, match="Values for 'peaks' must be"):
evoked.plot_joint(times={"peaks": "invalid_string"})
23 changes: 23 additions & 0 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,29 @@ def _find_peaks(evoked, npeaks):

def _process_times(inst, use_times, n_peaks=None, few=False):
"""Return a list of times for topomaps."""
if isinstance(use_times, dict):
if "peaks" not in use_times:
raise ValueError("If 'times' is a dict, the key must be 'peaks'.")

peak_params = use_times["peaks"]

if isinstance(peak_params, int):
use_times = _find_peaks(inst, peak_params)
elif isinstance(peak_params, (list, tuple)):
peaks = []
for window in peak_params:
if len(window) != 2:
raise ValueError(f"Each window must be (tmin, tmax), got {window}")
tmin, tmax = window
_, t_peak = inst.get_peak(tmin=tmin, tmax=tmax, mode="abs")
peaks.append(t_peak)
use_times = np.array(peaks)
else:
raise ValueError(
"Values for 'peaks' must be an integer or a list/tuple of "
" (tmin, tmax) windows."
)

if isinstance(use_times, str):
if use_times == "interactive":
use_times, n_peaks = "peaks", 1
Expand Down