diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index ccf79f334..4dcfabaf8 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -34,7 +34,7 @@ jobs: - name: Install package run: | python -m pip install --progress-bar off --upgrade pip setuptools - python -m pip install --progress-bar off .[test] + python -m pip install --progress-bar off ".[test,torch,onnx]" - run: mne_icalabel-sys_info --developer - run: pytest mne_icalabel --cov=mne_icalabel --cov-report=xml --cov-config=pyproject.toml - uses: codecov/codecov-action@v5 @@ -64,7 +64,7 @@ jobs: - name: Install dependencies run: | python -m pip install --progress-bar off --upgrade pip setuptools - python -m pip install --progress-bar off .[test] + python -m pip install --progress-bar off ".[test,torch,onnx]" python -m pip install --progress-bar off --upgrade git+https://github.com/mne-tools/mne-python python -m pip install --progress-bar off --upgrade git+https://github.com/mne-tools/mne-bids python -m pip install --progress-bar off --upgrade --pre --only-binary :all: -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple --timeout=180 numpy scipy matplotlib @@ -77,3 +77,35 @@ jobs: name: codecov-umbrella # optional token: ${{ secrets.CODECOV_TOKEN }} verbose: true # optional (default = false) + + pytest-backend: + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + os: [ubuntu] + backend: ["torch", "onnx"] + python-version: ["3.12"] + name: ${{ matrix.os }} - ${{ matrix.backend }} backend - py${{ matrix.python-version }} + runs-on: ${{ matrix.os }}-latest + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - uses: pyvista/setup-headless-display-action@main + with: + qt: true + - name: Install package + run: | + python -m pip install --progress-bar off --upgrade pip setuptools + python -m pip install --progress-bar off ".[test,${{ matrix.backend }}]"; + - run: mne_icalabel-sys_info --developer + - run: pytest mne_icalabel --cov=mne_icalabel --cov-report=xml --cov-config=pyproject.toml + - uses: codecov/codecov-action@v5 + with: + files: ./coverage.xml + flags: unittests # optional + name: codecov-umbrella # optional + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true # optional (default = false) diff --git a/doc/install.rst b/doc/install.rst index 9739e24ca..9894dfde4 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -43,6 +43,12 @@ Methods pip install torch pip install onnxruntime + + .. note:: + + If you are working with MEG data and plan to use the MEGnet model, e.g. + :func:`mne_icalabel.megnet.megnet_label_components`, you *must* install + ``onnxruntime``, and do not need to install ``torch``. Additional dependencies can be installed with different keywords: diff --git a/mne_icalabel/iclabel/network/tests/test_network.py b/mne_icalabel/iclabel/network/tests/test_network.py index 68ff49efb..5666015cc 100644 --- a/mne_icalabel/iclabel/network/tests/test_network.py +++ b/mne_icalabel/iclabel/network/tests/test_network.py @@ -2,13 +2,14 @@ import numpy as np import pytest -import torch from scipy.io import loadmat from mne_icalabel.datasets import icalabel from mne_icalabel.iclabel.network.utils import _format_input from mne_icalabel.utils._tests import requires_module +torch = pytest.importorskip("torch") + dataset_path = icalabel.data_path() / "iclabel" diff --git a/mne_icalabel/megnet/label_components.py b/mne_icalabel/megnet/label_components.py index 0be37cc5f..27355f06b 100644 --- a/mne_icalabel/megnet/label_components.py +++ b/mne_icalabel/megnet/label_components.py @@ -4,13 +4,14 @@ from typing import TYPE_CHECKING import numpy as np -import onnxruntime as ort from mne.io import BaseRaw from mne.preprocessing import ICA +from ..utils._imports import import_optional_dependency from .features import get_megnet_features if TYPE_CHECKING: + import onnxruntime as ort from numpy.typing import NDArray _MODEL_PATH: str = files("mne_icalabel.megnet") / "assets" / "megnet.onnx" @@ -43,6 +44,7 @@ def megnet_label_components(raw: BaseRaw, ica: ICA) -> NDArray: ---------- .. footbibliography:: """ + ort = import_optional_dependency("onnxruntime") time_series, topomaps = get_megnet_features(raw, ica) # sanity-checks diff --git a/mne_icalabel/megnet/tests/test_label_components.py b/mne_icalabel/megnet/tests/test_label_components.py index f47f15ac6..92b191b80 100644 --- a/mne_icalabel/megnet/tests/test_label_components.py +++ b/mne_icalabel/megnet/tests/test_label_components.py @@ -5,7 +5,6 @@ import mne import numpy as np -import onnxruntime as ort import pytest from mne.io.base import BaseRaw from mne.preprocessing.ica import ICA @@ -17,6 +16,8 @@ megnet_label_components, ) +ort = pytest.importorskip("onnxruntime") + if TYPE_CHECKING: from mne.io import BaseRaw from mne.preprocessing import ICA diff --git a/pyproject.toml b/pyproject.toml index bb639ab18..8758af0c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,8 +113,6 @@ test = [ 'mne-bids>=0.14', 'mne-icalabel[gui]', 'mne-icalabel[ica]', - 'mne-icalabel[onnx]', - 'mne-icalabel[torch]', 'pandas', 'pymatreader', 'PyQt6',