Skip to content

Commit 1c30841

Browse files
committed
Fix torch optional import logic to avoid errors when not installed
1 parent 7208d9f commit 1c30841

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

s2fft/utils/torch_wrapper.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@
4444
try:
4545
import torch
4646
import torch.utils.dlpack
47+
from torch import Tensor
4748

4849
TORCH_AVAILABLE = True
4950
except ImportError:
51+
Tensor = None
5052
TORCH_AVAILABLE = False
5153

5254
T = TypeVar("T")
@@ -65,7 +67,7 @@ def check_torch_available() -> None:
6567
raise RuntimeError(msg)
6668

6769

68-
def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
70+
def jax_array_to_torch_tensor(jax_array: jax.Array) -> Tensor:
6971
"""
7072
Convert from JAX array to Torch tensor via mutual DLPack support.
7173
@@ -84,7 +86,7 @@ def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
8486
return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))
8587

8688

87-
def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:
89+
def torch_tensor_to_jax_array(torch_tensor: Tensor) -> jax.Array:
8890
"""
8991
Convert from Torch tensor to JAX array via mutual DLPack support.
9092
@@ -117,7 +119,7 @@ def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:
117119

118120
def tree_map_jax_array_to_torch_tensor(
119121
jax_pytree: PyTree[jax.Array],
120-
) -> PyTree[torch.Tensor]:
122+
) -> PyTree[Tensor]:
121123
"""
122124
Convert from a pytree with JAX arrays to corresponding pytree with Torch tensors.
123125
@@ -135,7 +137,7 @@ def tree_map_jax_array_to_torch_tensor(
135137

136138

137139
def tree_map_torch_tensor_to_jax_array(
138-
torch_pytree: PyTree[torch.Tensor],
140+
torch_pytree: PyTree[Tensor],
139141
) -> PyTree[jax.Array]:
140142
"""
141143
Convert from a pytree with Torch tensors to corresponding pytree with JAX arrays.
@@ -148,7 +150,7 @@ def tree_map_torch_tensor_to_jax_array(
148150
149151
"""
150152
return tree_map(
151-
lambda t: torch_tensor_to_jax_array(t) if isinstance(t, torch.Tensor) else t,
153+
lambda t: torch_tensor_to_jax_array(t) if isinstance(t, Tensor) else t,
152154
torch_pytree,
153155
)
154156

@@ -176,7 +178,6 @@ def wrap_as_torch_function(
176178
Wrapped function callable from Torch.
177179
178180
"""
179-
check_torch_available()
180181
sig = signature(jax_function)
181182
if differentiable_argnames is None:
182183
differentiable_argnames = tuple(
@@ -191,6 +192,7 @@ def wrap_as_torch_function(
191192

192193
@wraps(jax_function)
193194
def torch_function(*args, **kwargs):
195+
check_torch_available()
194196
bound_args = sig.bind(*args, **kwargs)
195197
bound_args.apply_defaults()
196198
differentiable_args = tuple(
@@ -235,7 +237,7 @@ def backward(ctx, *grad_outputs):
235237
torch_function.__annotations__ = torch_function.__annotations__.copy()
236238
for name, annotation in torch_function.__annotations__.items():
237239
if isinstance(annotation, type) and issubclass(annotation, jax.Array):
238-
torch_function.__annotations__[name] = torch.Tensor
240+
torch_function.__annotations__[name] = Tensor
239241

240242
return torch_function
241243

0 commit comments

Comments
 (0)