Skip to content

Commit bb8d760

Browse files
larsonerautofix-ci[bot]drammock
authored
ENH: Speed up forward computations for iterative fitting (#13407)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy <[email protected]>
1 parent 07a63fd commit bb8d760

File tree

22 files changed

+355
-151
lines changed

22 files changed

+355
-151
lines changed

doc/changes/dev/13407.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with :func:`mne.make_forward_solution` where sources were not checked to make sure they're inside the inner skull for spherical BEMs, by `Eric Larson`_.

examples/simulation/simulate_raw_data.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
simulate_sparse_stc,
3232
)
3333

34-
print(__doc__)
35-
3634
data_path = sample.data_path()
3735
meg_path = data_path / "MEG" / "sample"
3836
raw_fname = meg_path / "sample_audvis_raw.fif"

examples/simulation/simulated_raw_data_using_subject_anatomy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
import mne
2727
from mne.datasets import sample
2828

29-
print(__doc__)
30-
3129
# %%
3230
# In this example, raw data will be simulated for the sample subject, so its
3331
# information needs to be loaded. This step will download the data if it not

mne/_ola.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__(
309309
# Create our window boundaries
310310
window_name = window if isinstance(window, str) else "custom"
311311
self._window = get_window(
312-
window, self._n_samples, fftbins=(self._n_samples - 1) % 2
312+
window, self._n_samples, fftbins=bool((self._n_samples - 1) % 2)
313313
)
314314
self._window /= _check_cola(
315315
self._window, self._n_samples, self._step, window_name, tol=tol

mne/chpi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
pick_types,
3939
)
4040
from ._fiff.proj import Projection, setup_proj
41+
from .bem import ConductorModel
4142
from .channels.channels import _get_meg_system
4243
from .cov import compute_whitener, make_ad_hoc_cov
4344
from .dipole import _make_guesses
@@ -1343,9 +1344,8 @@ def compute_chpi_locs(
13431344

13441345
# Make some location guesses (1 cm grid)
13451346
R = np.linalg.norm(meg_coils[0], axis=1).min()
1346-
guesses = _make_guesses(
1347-
dict(R=R, r0=np.zeros(3)), 0.01, 0.0, 0.005, verbose=safe_false
1348-
)[0]["rr"]
1347+
sphere = ConductorModel(layers=[dict(rad=R)], r0=np.zeros(3), is_sphere=True)
1348+
guesses = _make_guesses(sphere, 0.01, 0.0, 0.005, verbose=safe_false)[0]["rr"]
13491349
logger.info(
13501350
f"Computing {len(guesses)} HPI location guesses "
13511351
f"(1 cm grid in a {R * 100:.1f} cm sphere)"

mne/conftest.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,14 +729,16 @@ def _evoked_cov_sphere(_evoked):
729729
evoked.pick(evoked.ch_names[::4])
730730
assert len(evoked.ch_names) == 77
731731
cov = mne.read_cov(fname_cov)
732-
sphere = mne.make_sphere_model("auto", "auto", evoked.info)
732+
sphere = mne.make_sphere_model(
733+
(0.0, 0.0, 0.04), 0.1, relative_radii=(0.995, 0.997, 0.998, 1.0)
734+
)
733735
return evoked, cov, sphere
734736

735737

736738
@pytest.fixture(scope="session")
737739
def _fwd_surf(_evoked_cov_sphere):
738740
"""Compute a forward for a surface source space."""
739-
evoked, cov, sphere = _evoked_cov_sphere
741+
evoked, _, sphere = _evoked_cov_sphere
740742
src_surf = mne.read_source_spaces(fname_src)
741743
return mne.make_forward_solution(
742744
evoked.info, fname_trans, src_surf, sphere, mindist=5.0
@@ -747,7 +749,7 @@ def _fwd_surf(_evoked_cov_sphere):
747749
def _fwd_subvolume(_evoked_cov_sphere):
748750
"""Compute a forward for a surface source space."""
749751
pytest.importorskip("nibabel")
750-
evoked, cov, sphere = _evoked_cov_sphere
752+
evoked, _, sphere = _evoked_cov_sphere
751753
volume_labels = ["Left-Cerebellum-Cortex", "right-Cerebellum-Cortex"]
752754
with pytest.raises(ValueError, match=r"Did you mean one of \['Right-Cere"):
753755
mne.setup_volume_source_space(
@@ -761,9 +763,12 @@ def _fwd_subvolume(_evoked_cov_sphere):
761763
subjects_dir=subjects_dir,
762764
add_interpolator=False,
763765
)
764-
return mne.make_forward_solution(
765-
evoked.info, fname_trans, src_vol, sphere, mindist=5.0
766+
fwd = mne.make_forward_solution(
767+
evoked.info, fname_trans, src_vol, sphere, mindist=1.0
766768
)
769+
nsrc = sum(s["nuse"] for s in src_vol)
770+
assert fwd["nsource"] == nsrc
771+
return fwd
767772

768773

769774
@pytest.fixture

mne/dipole.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ._fiff.pick import pick_types
1818
from ._fiff.proj import _needs_eeg_average_ref_proj, make_projector
1919
from ._freesurfer import _get_aseg, head_to_mni, head_to_mri, read_freesurfer_lut
20-
from .bem import _bem_find_surface, _bem_surf_name, _fit_sphere
20+
from .bem import ConductorModel, _bem_find_surface, _bem_surf_name, _fit_sphere
2121
from .cov import _ensure_cov, compute_whitener
2222
from .evoked import _aspect_rev, _read_evoked, _write_evokeds
2323
from .fixes import _safe_svd
@@ -960,9 +960,8 @@ def _make_guesses(surf, grid, exclude, mindist, n_jobs=None, verbose=None):
960960
)
961961
else:
962962
logger.info(
963-
"Making a spherical guess space with radius {:7.1f} mm...".format(
964-
1000 * surf["R"]
965-
)
963+
f"Making a spherical guess space with radius {1000 * surf.radius:7.1f} "
964+
"mm..."
966965
)
967966
logger.info("Filtering (grid = %6.f mm)..." % (1000 * grid))
968967
src = _make_volume_source_space(
@@ -1317,7 +1316,7 @@ def _fit_dipole(
13171316
constraint = partial(
13181317
_sphere_constraint,
13191318
r0=fwd_data["inner_skull"]["r0"],
1320-
R_adj=fwd_data["inner_skull"]["R"] - min_dist_to_inner_skull,
1319+
R_adj=fwd_data["inner_skull"].radius - min_dist_to_inner_skull,
13211320
)
13221321

13231322
# Find a good starting point (find_best_guess in C)
@@ -1600,13 +1599,14 @@ def fit_dipole(
16001599
raise RuntimeError(
16011600
"No MEG channels found, but MEG-only sphere model used"
16021601
)
1603-
R = np.min(np.sqrt(np.sum(R * R, axis=1))) # use dist to sensors
1604-
kind = "max_rad"
1602+
R = np.min(np.linalg.norm(R, axis=1))
1603+
kind = "min_rad"
16051604
logger.info(
16061605
f"Sphere model : origin at ({1000 * r0[0]: 7.2f} {1000 * r0[1]: 7.2f} "
16071606
f"{1000 * r0[2]: 7.2f}) mm, {kind} = {R:6.1f} mm"
16081607
)
1609-
inner_skull = dict(R=R, r0=r0) # NB sphere model defined in head frame
1608+
# NB sphere model defined in head frame
1609+
inner_skull = ConductorModel(layers=[dict(rad=R)], r0=r0, is_sphere=True)
16101610
del R, r0
16111611

16121612
# Deal with DipoleFixed cases here
@@ -1710,7 +1710,9 @@ def fit_dipole(
17101710
check = _surface_constraint(pos, inner_skull, min_dist_to_inner_skull)
17111711
else:
17121712
check = _sphere_constraint(
1713-
pos, inner_skull["r0"], R_adj=inner_skull["R"] - min_dist_to_inner_skull
1713+
pos,
1714+
inner_skull["r0"],
1715+
R_adj=inner_skull.radius - min_dist_to_inner_skull,
17141716
)
17151717
if check <= 0:
17161718
raise ValueError(
@@ -1720,7 +1722,7 @@ def fit_dipole(
17201722

17211723
# C code computes guesses w/sphere model for speed, don't bother here
17221724
fwd_data = _prep_field_computation(
1723-
guess_src["rr"], sensors=sensors, bem=bem, n_jobs=n_jobs, verbose=safe_false
1725+
sensors=sensors, bem=bem, n_jobs=n_jobs, verbose=safe_false
17241726
)
17251727
fwd_data["inner_skull"] = inner_skull
17261728
guess_fwd, guess_fwd_orig, guess_fwd_scales = _dipole_forwards(

mne/forward/_compute_forward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close):
709709

710710

711711
@verbose
712-
def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None):
712+
def _prep_field_computation(*, sensors, bem, n_jobs, verbose=None):
713713
"""Precompute and store some things that are used for both MEG and EEG.
714714
715715
Calculation includes multiplication factors, coordinate transforms,
@@ -840,7 +840,7 @@ def _compute_forwards(rr, *, bem, sensors, n_jobs, verbose=None):
840840
# This modifies "sensors" in place, so let's copy it in case the calling
841841
# function needs to reuse it (e.g., in simulate_raw.py)
842842
sensors = deepcopy(sensors)
843-
fwd_data = _prep_field_computation(rr, sensors=sensors, bem=bem, n_jobs=n_jobs)
843+
fwd_data = _prep_field_computation(sensors=sensors, bem=bem, n_jobs=n_jobs)
844844
Bs = _compute_forwards_meeg(
845845
rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs
846846
)

mne/forward/_make_forward.py

Lines changed: 111 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
from ..bem import ConductorModel, _bem_find_surface, read_bem_solution
2222
from ..source_estimate import VolSourceEstimate
2323
from ..source_space._source_space import (
24+
SourceSpaces,
2425
_complete_vol_src,
2526
_ensure_src,
2627
_filter_source_spaces,
2728
_make_discrete_source_space,
2829
)
29-
from ..surface import _CheckInside, _normalize_vectors
30+
from ..surface import _CheckInside, _CheckInsideSphere, _normalize_vectors
3031
from ..transforms import (
3132
Transform,
3233
_coord_frame_name,
@@ -35,10 +36,13 @@
3536
_print_coord_trans,
3637
apply_trans,
3738
invert_transform,
38-
transform_surface_to,
3939
)
4040
from ..utils import _check_fname, _pl, _validate_type, logger, verbose, warn
41-
from ._compute_forward import _compute_forwards
41+
from ._compute_forward import (
42+
_compute_forwards,
43+
_compute_forwards_meeg,
44+
_prep_field_computation,
45+
)
4246
from .forward import _FWD_ORDER, Forward, _merge_fwds, convert_forward_solution
4347

4448
_accuracy_dict = dict(
@@ -459,7 +463,7 @@ def _prepare_for_forward(
459463
# let's make a copy in case we modify something
460464
src = _ensure_src(src).copy()
461465
nsource = sum(s["nuse"] for s in src)
462-
if nsource == 0:
466+
if len(src) and nsource == 0:
463467
raise RuntimeError(
464468
"No sources are active in these source spaces. "
465469
'"do_all" option should be used.'
@@ -517,11 +521,12 @@ def _prepare_for_forward(
517521

518522
# Transform the source spaces into the appropriate coordinates
519523
# (will either be HEAD or MRI)
520-
for s in src:
521-
transform_surface_to(s, "head", mri_head_t)
522-
logger.info(
523-
f"Source spaces are now in {_coord_frame_name(s['coord_frame'])} coordinates."
524-
)
524+
src._transform_to("head", mri_head_t)
525+
if len(src):
526+
logger.info(
527+
f"Source spaces are now in {_coord_frame_name(src[0]['coord_frame'])} "
528+
"coordinates."
529+
)
525530

526531
# Prepare the BEM model
527532
eegnames = sensors.get("eeg", dict()).get("ch_names", [])
@@ -533,48 +538,50 @@ def _prepare_for_forward(
533538
# Circumvent numerical problems by excluding points too close to the skull,
534539
# and check that sensors are not inside any BEM surface
535540
if bem is not None:
541+
kwargs = dict(limit=mindist, mri_head_t=mri_head_t, src=src)
536542
if not bem["is_sphere"]:
537543
check_surface = "inner skull surface"
538-
inner_skull = _bem_find_surface(bem, "inner_skull")
539-
check_inside = _filter_source_spaces(
540-
inner_skull, mindist, mri_head_t, src, n_jobs
541-
)
544+
check_inside_brain = _CheckInside(_bem_find_surface(bem, "inner_skull"))
542545
logger.info("")
543546
if len(bem["surfs"]) == 3:
544547
check_surface = "scalp surface"
545-
check_inside = _CheckInside(_bem_find_surface(bem, "head"))
548+
check_inside_head = _CheckInside(_bem_find_surface(bem, "head"))
549+
else:
550+
check_inside_head = check_inside_brain
546551
else:
547552
check_surface = "outermost sphere shell"
548-
if len(bem["layers"]) == 0:
553+
check_inside_brain = _CheckInsideSphere(bem)
554+
if bem.radius is not None:
555+
check_inside_head = _CheckInsideSphere(bem, check="outer")
556+
else:
549557

550-
def check_inside(x):
558+
def check_inside_head(x):
551559
return np.zeros(len(x), bool)
552560

553-
else:
554-
555-
def check_inside(x):
556-
r0 = apply_trans(invert_transform(mri_head_t), bem["r0"])
557-
return np.linalg.norm(x - r0, axis=1) < bem["layers"][-1]["rad"]
561+
if len(src):
562+
_filter_source_spaces(check_inside_brain, **kwargs)
558563

559564
if "meg" in sensors:
560-
meg_loc = apply_trans(
561-
invert_transform(mri_head_t),
562-
np.array([coil["r0"] for coil in sensors["meg"]["defs"]]),
563-
)
564-
n_inside = check_inside(meg_loc).sum()
565+
meg_loc = np.array([coil["r0"] for coil in sensors["meg"]["defs"]])
566+
if not bem["is_sphere"]:
567+
meg_loc = apply_trans(invert_transform(mri_head_t), meg_loc)
568+
n_inside = check_inside_head(meg_loc).sum()
565569
if n_inside:
566570
raise RuntimeError(
567571
f"Found {n_inside} MEG sensor{_pl(n_inside)} inside the "
568572
f"{check_surface}, perhaps coordinate frames and/or "
569573
"coregistration must be incorrect"
570574
)
571575

572-
rr = np.concatenate([s["rr"][s["vertno"]] for s in src])
573-
if len(rr) < 1:
574-
raise RuntimeError(
575-
"No points left in source space after excluding "
576-
"points close to inner skull."
577-
)
576+
if len(src):
577+
rr = np.concatenate([s["rr"][s["vertno"]] for s in src])
578+
if len(rr) < 1:
579+
raise RuntimeError(
580+
"No points left in source space after excluding "
581+
"points close to inner skull."
582+
)
583+
else:
584+
rr = np.zeros((0, 3))
578585

579586
# deal with free orientations:
580587
source_nn = np.tile(np.eye(3), (len(rr), 1))
@@ -934,3 +941,75 @@ def use_coil_def(fname):
934941
yield
935942
finally:
936943
_extra_coil_def_fname = None
944+
945+
946+
class _ForwardModeler:
947+
"""Optimized incremental fitting using the same sensors and BEM."""
948+
949+
@verbose
950+
def __init__(
951+
self,
952+
info,
953+
trans,
954+
bem,
955+
*,
956+
mindist=0.0,
957+
n_jobs=1,
958+
verbose=None,
959+
):
960+
self.mri_head_t, _ = _get_trans(trans)
961+
self.mindist = mindist
962+
self.n_jobs = n_jobs
963+
src = SourceSpaces([])
964+
self.sensors, _, _, _, self.bem = _prepare_for_forward(
965+
src,
966+
self.mri_head_t,
967+
info,
968+
bem,
969+
mindist,
970+
n_jobs,
971+
bem_extra="",
972+
trans="",
973+
info_extra="",
974+
meg=True,
975+
eeg=True,
976+
ignore_ref=False,
977+
)
978+
self.fwd_data = _prep_field_computation(
979+
sensors=self.sensors,
980+
bem=self.bem,
981+
n_jobs=self.n_jobs,
982+
)
983+
if self.bem["is_sphere"]:
984+
self.check_inside = _CheckInsideSphere(self.bem)
985+
else:
986+
self.check_inside = _CheckInside(_bem_find_surface(self.bem, "inner_skull"))
987+
988+
def compute(self, src):
989+
src = _ensure_src(src).copy()
990+
src._transform_to("head", self.mri_head_t)
991+
kwargs = dict(limit=self.mindist, mri_head_t=self.mri_head_t, src=src)
992+
_filter_source_spaces(self.check_inside, n_jobs=self.n_jobs, **kwargs)
993+
rr = np.concatenate([s["rr"][s["vertno"]] for s in src])
994+
if len(rr) < 1:
995+
raise RuntimeError(
996+
"No points left in source space after excluding "
997+
"points close to inner skull."
998+
)
999+
1000+
sensors = deepcopy(self.sensors)
1001+
fwd_data = deepcopy(self.fwd_data)
1002+
fwds = _compute_forwards_meeg(
1003+
rr,
1004+
sensors=sensors,
1005+
fwd_data=fwd_data,
1006+
n_jobs=self.n_jobs,
1007+
)
1008+
fwds = {
1009+
key: _to_forward_dict(fwds[key], sensors[key]["ch_names"])
1010+
for key in _FWD_ORDER
1011+
if key in fwds
1012+
}
1013+
fwd = _merge_fwds(fwds, verbose=False)
1014+
del fwds
1015+
return fwd

0 commit comments

Comments
 (0)