File tree Expand file tree Collapse file tree 2 files changed +25
-3
lines changed
Expand file tree Collapse file tree 2 files changed +25
-3
lines changed Original file line number Diff line number Diff line change @@ -32,7 +32,6 @@ dependencies = [
3232 " pyyaml" ,
3333 " jax>=0.3.13" ,
3434 " jaxlib" ,
35- " torch" ,
3635]
3736dynamic = [
3837 " version" ,
@@ -75,6 +74,10 @@ tests = [
7574 " pytest-cov" ,
7675 " so3" ,
7776 " pyssht" ,
77+ " torch" ,
78+ ]
79+ torch = [
80+ " torch" ,
7881]
7982
8083[tool .scikit-build ]
Original file line number Diff line number Diff line change 3737
3838import jax
3939import jax .dlpack
40- import torch
41- import torch .utils .dlpack
4240from jax .tree import map as tree_map
4341
42+ try :
43+ import torch
44+ import torch .utils .dlpack
45+
46+ TORCH_AVAILABLE = True
47+ except ImportError :
48+ TORCH_AVAILABLE = False
49+
4450T = TypeVar ("T" )
4551PyTree = dict [Any , "PyTree" ] | list ["PyTree" ] | tuple ["PyTree" ] | T
4652
4753
54+ def check_torch_available () -> None :
55+ """Raise an error if Torch is not importable."""
56+ if not TORCH_AVAILABLE :
57+ msg = (
58+ "torch needs to be installed to use torch wrapper functionality but could\n "
59+ "not be imported. Install s2fft with torch extra using:\n "
60+ " pip install s2fft[torch]\n "
61+ "to allow use of torch wrapper functionality."
62+ )
63+ raise RuntimeError (msg )
64+
65+
4866def jax_array_to_torch_tensor (jax_array : jax .Array ) -> torch .Tensor :
4967 """
5068 Convert from JAX array to Torch tensor via mutual DLPack support.
@@ -138,6 +156,7 @@ def wrap_as_torch_function(
138156 Wrapped function callable from Torch.
139157
140158 """
159+ check_torch_available ()
141160 sig = signature (jax_function )
142161 if differentiable_argnames is None :
143162 differentiable_argnames = tuple (
You can’t perform that action at this time.
0 commit comments