Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions backend/find_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import os
import platform
import site
from functools import (
lru_cache,
Expand Down Expand Up @@ -30,6 +29,10 @@
Version,
)

from .utils import (
read_dependencies_from_dependency_group,
)


@lru_cache
def find_pytorch() -> tuple[Optional[str], list[str]]:
Expand Down Expand Up @@ -108,15 +111,15 @@ def get_pt_requirement(pt_version: str = "") -> dict:
"""
if pt_version is None:
return {"torch": []}
if (
os.environ.get("CIBUILDWHEEL", "0") == "1"
and platform.system() == "Linux"
and platform.machine() == "x86_64"
):
cibw_requirement = []
if os.environ.get("CIBUILDWHEEL", "0") == "1":
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
# CUDA 12.2, cudnn 9
pt_version = "2.8.0"
# or CPU builds
cibw_requirement = read_dependencies_from_dependency_group(
"pin_pytorch_cpu"
)
elif cuda_version in SpecifierSet(">=11,<12"):
# CUDA 11.8, cudnn 8
pt_version = "2.3.1"
Expand All @@ -141,6 +144,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
# https://github.com/pytorch/pytorch/commit/7e0c26d4d80d6602aed95cb680dfc09c9ce533bc
else "torch>=2.1.0",
*mpi_requirement,
*cibw_requirement,
],
}

Expand Down
9 changes: 6 additions & 3 deletions backend/find_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
SpecifierSet,
)

from .utils import (
read_dependencies_from_dependency_group,
)


@lru_cache
def find_tensorflow() -> tuple[Optional[str], list[str]]:
Expand Down Expand Up @@ -91,10 +95,9 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]:
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
# CUDA 12.2, cudnn 9
# or CPU builds
requires.extend(
[
"tensorflow-cpu>=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'",
]
read_dependencies_from_dependency_group("pin_tensorflow_cpu")
)
elif cuda_version in SpecifierSet(">=11,<12"):
# CUDA 11.8, cudnn 8
Expand Down
36 changes: 36 additions & 0 deletions backend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
from pathlib import (
Path,
)

from dependency_groups import (
resolve,
)

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib


def read_dependencies_from_dependency_group(group: str) -> tuple[str, ...]:
"""
Reads dependencies from a dependency group.

Parameters
----------
group : str
The name of the dependency group.

Returns
-------
tuple[str, ...]
A tuple of dependencies in the specified group.
"""
with Path("pyproject.toml").open("rb") as f:
pyproject = tomllib.load(f)

groups = pyproject["dependency-groups"]

return resolve(groups, group)
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ requires = [
"scikit-build-core>=0.5,<0.11,!=0.6.0",
"packaging",
'tomli >= 1.1.0 ; python_version < "3.11"',
"dependency_groups",
]
build-backend = "backend.dp_backend"
backend-path = ["."]
Expand Down Expand Up @@ -159,13 +160,20 @@ dev = [
"mpich",
]
pin_tensorflow_cpu = [
"tensorflow-cpu~=2.18.0",
# https://github.com/tensorflow/tensorflow/issues/75279
# macos x86 has been deprecated
"tensorflow-cpu~=2.18.0; platform_machine=='x86_64' and platform_system == 'Linux'",
"tensorflow~=2.18.0; (platform_machine!='x86_64' or platform_system != 'Linux') and (platform_machine!='x86_64' or platform_system != 'Darwin')",
"tensorflow; platform_machine=='x86_64' and platform_system == 'Darwin'",
]
pin_tensorflow_gpu = [
"tensorflow~=2.18.0",
]
pin_pytorch_cpu = [
"torch~=2.8.0",
# https://github.com/pytorch/pytorch/issues/114602
# macos x86 has been deprecated
"torch~=2.8.0; platform_machine!='x86_64' or platform_system != 'Darwin'",
"torch; platform_machine=='x86_64' and platform_system == 'Darwin'",
]
pin_pytorch_gpu = [
"torch~=2.7.0",
Expand Down