Skip to content
Merged
7 changes: 4 additions & 3 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ jobs:
matrix:
os: [ubuntu, macos, windows]
python-version: ["3.10", "3.11", "3.12", "3.13"]
name: ${{ matrix.os }} - py${{ matrix.python-version }}
kind: ["torch", "onnx"]
name: ${{ matrix.os }} - py${{ matrix.python-version }} - ${{ matrix.kind }}
runs-on: ${{ matrix.os }}-latest
defaults:
run:
Expand All @@ -34,7 +35,7 @@ jobs:
- name: Install package
run: |
python -m pip install --progress-bar off --upgrade pip setuptools
python -m pip install --progress-bar off .[test]
python -m pip install --progress-bar off ".[test,${{ matrix.kind }}]";
- run: mne_icalabel-sys_info --developer
- run: pytest mne_icalabel --cov=mne_icalabel --cov-report=xml --cov-config=pyproject.toml
- uses: codecov/codecov-action@v5
Expand Down Expand Up @@ -64,7 +65,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --progress-bar off --upgrade pip setuptools
python -m pip install --progress-bar off .[test]
python -m pip install --progress-bar off ".[test,torch,onnx]"
python -m pip install --progress-bar off --upgrade git+https://github.com/mne-tools/mne-python
python -m pip install --progress-bar off --upgrade git+https://github.com/mne-tools/mne-bids
python -m pip install --progress-bar off --upgrade --pre --only-binary :all: -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple --timeout=180 numpy scipy matplotlib
Expand Down
3 changes: 2 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@
"pandas": ("https://pandas.pydata.org/pandas-docs/dev", None),
"pooch": ("https://www.fatiando.org/pooch/latest/", None),
"python": ("https://docs.python.org/3", None),
"scipy": ("https://docs.scipy.org/doc/scipy", None),
# Can use stable after https://github.com/scipy/scipy/issues/23757 is addressed
"scipy": ("https://scipy.github.io/devdocs/", None),
"sklearn": ("https://scikit-learn.org/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
}
Expand Down
6 changes: 6 additions & 0 deletions doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ Methods

pip install torch
pip install onnxruntime

.. note::

If you are working with MEG data and plan to use the MEGnet model, e.g.
:func:`mne_icalabel.megnet.megnet_label_components`, you *must* install
onnxruntime, and do not need to install torch.

Additional dependencies can be installed with different keywords:

Expand Down
3 changes: 2 additions & 1 deletion mne_icalabel/iclabel/network/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import numpy as np
import pytest
import torch
from scipy.io import loadmat

from mne_icalabel.datasets import icalabel
from mne_icalabel.iclabel.network.utils import _format_input
from mne_icalabel.utils._tests import requires_module

torch = pytest.importorskip("torch")

dataset_path = icalabel.data_path() / "iclabel"


Expand Down
4 changes: 3 additions & 1 deletion mne_icalabel/megnet/label_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from typing import TYPE_CHECKING

import numpy as np
import onnxruntime as ort
from mne.io import BaseRaw
from mne.preprocessing import ICA

from ..utils._imports import import_optional_dependency
from .features import get_megnet_features

if TYPE_CHECKING:
import onnxruntime as ort
from numpy.typing import NDArray

_MODEL_PATH: str = files("mne_icalabel.megnet") / "assets" / "megnet.onnx"
Expand Down Expand Up @@ -43,6 +44,7 @@ def megnet_label_components(raw: BaseRaw, ica: ICA) -> NDArray:
----------
.. footbibliography::
"""
ort = import_optional_dependency("onnxruntime")
time_series, topomaps = get_megnet_features(raw, ica)

# sanity-checks
Expand Down
3 changes: 2 additions & 1 deletion mne_icalabel/megnet/tests/test_label_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import mne
import numpy as np
import onnxruntime as ort
import pytest
from mne.io.base import BaseRaw
from mne.preprocessing.ica import ICA
Expand All @@ -17,6 +16,8 @@
megnet_label_components,
)

ort = pytest.importorskip("onnxruntime")

if TYPE_CHECKING:
from mne.io import BaseRaw
from mne.preprocessing import ICA
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ test = [
'mne-bids>=0.14',
'mne-icalabel[gui]',
'mne-icalabel[ica]',
'mne-icalabel[onnx]',
'mne-icalabel[torch]',
'pandas',
'pymatreader',
'PyQt6',
Expand Down