Skip to content

Commit a0fe321

Browse files
committed
Make torch an optional dependency
1 parent 6761781 commit a0fe321

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ dependencies = [
3232
"pyyaml",
3333
"jax>=0.3.13",
3434
"jaxlib",
35-
"torch",
3635
]
3736
dynamic = [
3837
"version",
@@ -75,6 +74,10 @@ tests = [
7574
"pytest-cov",
7675
"so3",
7776
"pyssht",
77+
"torch",
78+
]
79+
torch = [
80+
"torch",
7881
]
7982

8083
[tool.scikit-build]

s2fft/utils/torch_wrapper.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,32 @@
3737

3838
import jax
3939
import jax.dlpack
40-
import torch
41-
import torch.utils.dlpack
4240
from jax.tree import map as tree_map
4341

42+
try:
43+
import torch
44+
import torch.utils.dlpack
45+
46+
TORCH_AVAILABLE = True
47+
except ImportError:
48+
TORCH_AVAILABLE = False
49+
4450
T = TypeVar("T")
4551
PyTree = dict[Any, "PyTree"] | list["PyTree"] | tuple["PyTree"] | T
4652

4753

54+
def check_torch_available() -> None:
55+
"""Raise an error if Torch is not importable."""
56+
if not TORCH_AVAILABLE:
57+
msg = (
58+
"torch needs to be installed to use torch wrapper functionality but could\n"
59+
"not be imported. Install s2fft with torch extra using:\n"
60+
" pip install s2fft[torch]\n"
61+
"to allow use of torch wrapper functionality."
62+
)
63+
raise RuntimeError(msg)
64+
65+
4866
def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
4967
"""
5068
Convert from JAX array to Torch tensor via mutual DLPack support.
@@ -138,6 +156,7 @@ def wrap_as_torch_function(
138156
Wrapped function callable from Torch.
139157
140158
"""
159+
check_torch_available()
141160
sig = signature(jax_function)
142161
if differentiable_argnames is None:
143162
differentiable_argnames = tuple(

0 commit comments

Comments
 (0)