diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py index 11a967b305..df52e63219 100644 --- a/backend/find_pytorch.py +++ b/backend/find_pytorch.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import importlib import os -import platform import site from functools import ( lru_cache, @@ -30,6 +29,10 @@ Version, ) +from .utils import ( + read_dependencies_from_dependency_group, +) + @lru_cache def find_pytorch() -> tuple[Optional[str], list[str]]: @@ -108,15 +111,15 @@ def get_pt_requirement(pt_version: str = "") -> dict: """ if pt_version is None: return {"torch": []} - if ( - os.environ.get("CIBUILDWHEEL", "0") == "1" - and platform.system() == "Linux" - and platform.machine() == "x86_64" - ): + cibw_requirement = [] + if os.environ.get("CIBUILDWHEEL", "0") == "1": cuda_version = os.environ.get("CUDA_VERSION", "12.2") if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"): # CUDA 12.2, cudnn 9 - pt_version = "2.8.0" + # or CPU builds + cibw_requirement = read_dependencies_from_dependency_group( + "pin_pytorch_cpu" + ) elif cuda_version in SpecifierSet(">=11,<12"): # CUDA 11.8, cudnn 8 pt_version = "2.3.1" @@ -141,6 +144,7 @@ def get_pt_requirement(pt_version: str = "") -> dict: # https://github.com/pytorch/pytorch/commit/7e0c26d4d80d6602aed95cb680dfc09c9ce533bc else "torch>=2.1.0", *mpi_requirement, + *cibw_requirement, ], } diff --git a/backend/find_tensorflow.py b/backend/find_tensorflow.py index a0a1e65aca..457e7a726c 100644 --- a/backend/find_tensorflow.py +++ b/backend/find_tensorflow.py @@ -26,6 +26,10 @@ SpecifierSet, ) +from .utils import ( + read_dependencies_from_dependency_group, +) + @lru_cache def find_tensorflow() -> tuple[Optional[str], list[str]]: @@ -91,10 +95,9 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]: cuda_version = os.environ.get("CUDA_VERSION", "12.2") if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"): # CUDA 12.2, cudnn 9 + # or CPU builds requires.extend( - [ - "tensorflow-cpu>=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'", - ] + read_dependencies_from_dependency_group("pin_tensorflow_cpu") ) elif cuda_version in SpecifierSet(">=11,<12"): # CUDA 11.8, cudnn 8 diff --git a/backend/utils.py b/backend/utils.py new file mode 100644 index 0000000000..0769879d24 --- /dev/null +++ b/backend/utils.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import sys +from pathlib import ( + Path, +) + +from dependency_groups import ( + resolve, +) + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + + +def read_dependencies_from_dependency_group(group: str) -> tuple[str, ...]: + """ + Reads dependencies from a dependency group. + + Parameters + ---------- + group : str + The name of the dependency group. + + Returns + ------- + tuple[str, ...] + A tuple of dependencies in the specified group. + """ + with Path("pyproject.toml").open("rb") as f: + pyproject = tomllib.load(f) + + groups = pyproject["dependency-groups"] + + return resolve(groups, group) diff --git a/pyproject.toml b/pyproject.toml index 751e6f2f7d..13bb96e18c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ requires = [ "scikit-build-core>=0.5,<0.11,!=0.6.0", "packaging", 'tomli >= 1.1.0 ; python_version < "3.11"', + "dependency_groups", ] build-backend = "backend.dp_backend" backend-path = ["."] @@ -159,13 +160,20 @@ dev = [ "mpich", ] pin_tensorflow_cpu = [ - "tensorflow-cpu~=2.18.0", + # https://github.com/tensorflow/tensorflow/issues/75279 + # macos x86 has been deprecated + "tensorflow-cpu~=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'", + "tensorflow~=2.18.0; (platform_machine!='x86_64' or platform_system != 'Linux') and (platform_machine!='x86_64' or platform_system != 'Darwin')", + "tensorflow; platform_machine=='x86_64' and platform_system == 'Darwin'", ] pin_tensorflow_gpu = [ "tensorflow~=2.18.0", ] pin_pytorch_cpu = [ - "torch~=2.8.0", + # https://github.com/pytorch/pytorch/issues/114602 + # macos x86 has been deprecated + "torch~=2.8.0; platform_machine!='x86_64' or platform_system != 'Darwin'", + "torch; platform_machine=='x86_64' and platform_system == 'Darwin'", ] pin_pytorch_gpu = [ "torch~=2.7.0",