 try:
     import torch
     import torch.utils.dlpack
+    from torch import Tensor

     TORCH_AVAILABLE = True
 except ImportError:
+    Tensor = None
     TORCH_AVAILABLE = False

 T = TypeVar("T")
@@ -65,7 +67,7 @@ def check_torch_available() -> None:
         raise RuntimeError(msg)


-def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
+def jax_array_to_torch_tensor(jax_array: jax.Array) -> Tensor:
     """
     Convert from JAX array to Torch tensor via mutual DLPack support.

@@ -84,7 +86,7 @@ def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
     return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))


-def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:
+def torch_tensor_to_jax_array(torch_tensor: Tensor) -> jax.Array:
     """
     Convert from Torch tensor to JAX array via mutual DLPack support.

@@ -117,7 +119,7 @@ def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:

 def tree_map_jax_array_to_torch_tensor(
     jax_pytree: PyTree[jax.Array],
-) -> PyTree[torch.Tensor]:
+) -> PyTree[Tensor]:
     """
     Convert from a pytree with JAX arrays to corresponding pytree with Torch tensors.

@@ -135,7 +137,7 @@ def tree_map_jax_array_to_torch_tensor(


 def tree_map_torch_tensor_to_jax_array(
-    torch_pytree: PyTree[torch.Tensor],
+    torch_pytree: PyTree[Tensor],
 ) -> PyTree[jax.Array]:
     """
     Convert from a pytree with Torch tensors to corresponding pytree with JAX arrays.
@@ -148,7 +150,7 @@ def tree_map_torch_tensor_to_jax_array(

     """
     return tree_map(
-        lambda t: torch_tensor_to_jax_array(t) if isinstance(t, torch.Tensor) else t,
+        lambda t: torch_tensor_to_jax_array(t) if isinstance(t, Tensor) else t,
         torch_pytree,
     )

@@ -176,7 +178,6 @@ def wrap_as_torch_function(
     Wrapped function callable from Torch.

     """
-    check_torch_available()
     sig = signature(jax_function)
     if differentiable_argnames is None:
         differentiable_argnames = tuple(
@@ -191,6 +192,7 @@

     @wraps(jax_function)
     def torch_function(*args, **kwargs):
+        check_torch_available()
         bound_args = sig.bind(*args, **kwargs)
         bound_args.apply_defaults()
         differentiable_args = tuple(
@@ -235,7 +237,7 @@ def backward(ctx, *grad_outputs):
     torch_function.__annotations__ = torch_function.__annotations__.copy()
     for name, annotation in torch_function.__annotations__.items():
         if isinstance(annotation, type) and issubclass(annotation, jax.Array):
-            torch_function.__annotations__[name] = torch.Tensor
+            torch_function.__annotations__[name] = Tensor

     return torch_function

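The diff makes torch an optional dependency: `Tensor` is aliased at import time so module-level annotations still resolve when torch is absent, and the hard failure is deferred to call time. A minimal standalone sketch of that pattern, assuming only the standard library plus torch (the `as_tensor` helper is illustrative, not part of the diff):

try:
    import torch
    from torch import Tensor

    TORCH_AVAILABLE = True
except ImportError:
    # Annotations such as "-> Tensor" still evaluate; they simply hold None.
    Tensor = None
    TORCH_AVAILABLE = False


def check_torch_available() -> None:
    # Raise only when torch functionality is actually exercised.
    if not TORCH_AVAILABLE:
        msg = "torch is required for this functionality but is not installed"
        raise RuntimeError(msg)


def as_tensor(x) -> Tensor:  # hypothetical helper for demonstration
    check_torch_available()  # importing this module never needs torch; calling does
    return torch.as_tensor(x)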
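For context, a hedged usage sketch: assuming `wrap_as_torch_function(jax_function, differentiable_argnames=None)` returns a Torch-callable wrapper (as the `signature(jax_function)` and `@wraps` usage above suggest), moving `check_torch_available()` into `torch_function` means applying the wrapper now succeeds even without torch installed, and only an actual call raises. The `scale` function below is illustrative, not from the diff:

import jax
import jax.numpy as jnp

@wrap_as_torch_function  # wrapping itself no longer requires torch
def scale(x: jax.Array, factor: float = 2.0) -> jax.Array:
    return factor * jnp.sin(x)

# Only calling the wrapper triggers check_torch_available():
y = scale(torch.ones(3, requires_grad=True))  # runs the JAX function, returns a Tensor
y.sum().backward()  # gradients flow back through the JAX computation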