-
This is probably related to #9804, but I am not sure how to progress from there. I am trying to wrap a custom linear operator for use with jax.scipy.sparse.linalg.cg.
The crux of the wrapping is as follows:

def j2t(x_jax):
    # Convert a JAX array to a torch tensor
    return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))

def t2j(x_torch):
    # Convert a torch tensor to a JAX array
    return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch.contiguous()))

A_op_torch_as_jax = lambda x_jax: t2j(A_torch @ j2t(x_jax))
x_res_jax, exit_code = jax.scipy.sparse.linalg.cg(A_op_torch_as_jax, B_jax)

A full working example is provided on Colab and below:
import traceback
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax import dlpack as jax_dlpack
import jaxlib
print(f'Running JAX version: {jax.__version__}')
#jaxdevice = jax.devices("cpu")[0]
jaxdevice = jax.devices()[0]
print('Running JAX on ' + str(jaxdevice))
import torch
from torch.utils import dlpack as torch_dlpack
print(f'Running PyTorch version: {torch.__version__}')
torchdevice = torch.device('cpu')
if torch.cuda.is_available():
    torchdevice = torch.device('cuda')
    print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running torch on ' + str(torchdevice))
import numpy as onp
def j2t(x_jax):
    # Convert a jax array to a torch tensor
    # See https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py
    #if not isinstance(x_jax, jaxlib.xla_extension.DeviceArray):
    #    print("\nExpected a JAX DeviceArray type, got x_jax:", type(x_jax))
    x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
    return x_torch
def t2j(x_torch):
    # Convert a torch tensor to a jax array
    # See https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py
    x_torch = x_torch.contiguous()  # https://github.com/google/jax/issues/8082
    x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
    return x_jax
rng = onp.random.default_rng()
A_n = rng.random((4, 4), dtype=onp.float32)
# Make A SPD
A_n = (A_n + A_n.T) + 2*4*onp.eye(4, dtype=onp.float32)
B_n = rng.standard_normal((4, 2), dtype=onp.float32)
x_j_from_n, exit_code = jax.scipy.sparse.linalg.cg(A_n, B_n)
print(f"x_j_from_n:\n{x_j_from_n}\nExit code: {exit_code}\n")
A_j = jax.device_put(A_n, device=jaxdevice)
B_j = jax.device_put(B_n, device=jaxdevice)
x_j_from_j, exit_code = jax.scipy.sparse.linalg.cg(A_j, B_j)
print(f"x_j_from_j:\n{x_j_from_j}\nExit code: {exit_code}\n")
A_op_n = lambda x : A_n @ x
x_j_from_opn, exit_code = jax.scipy.sparse.linalg.cg(A_op_n, B_n)
print(f"x_j_from_opn:\n{x_j_from_opn}\nExit code: {exit_code}\n")
A_t = torch.from_numpy(A_n).to(torchdevice)
A_op_t = lambda x : t2j( A_t @ j2t(x) )
print(f"A_n @ B_n = {A_n @ B_n}")
print(f"A_op_t( B_j ) = {A_op_t( B_j )}")
try:
    x_j_from_opt, exit_code = jax.scipy.sparse.linalg.cg(A_op_t, B_j)
    print(f"x_j_from_opt:\n{x_j_from_opt}\nExit code: {exit_code}\n")
except Exception as e:
    print("cg(A_op_t, B_j) failed with ", e)
    print(traceback.format_exc())
try:
    x_j_from_opt, exit_code = jax.scipy.sparse.linalg.cg(jax.tree_util.Partial(A_op_t), B_j)
    print(f"x_j_from_opt:\n{x_j_from_opt}\nExit code: {exit_code}\n")
except Exception as e:
    print("cg(jax.tree_util.Partial(A_op_t), B_j) failed with ", e)
    print(traceback.format_exc())

leading to
Suggestions on how to fix this would be great. Also, if helpful, the same error occurs when attempting to use
-
Hi - thanks for the question! This is expected behavior. JAX transformations work by replacing JAX arrays with tracers (you can get more intuition for this in How To Think In JAX). It's a really nice abstraction, but unfortunately it means that within transformed functions you can normally only call JAX code, not external code like PyTorch.
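To make the tracer point concrete, here is a minimal sketch (the function name is hypothetical, and jax.jit stands in for the tracing that cg performs internally): printing inside a transformed function shows an abstract Traced value rather than a concrete array, which is why the DLPack conversion inside A_op_t cannot work there.

import jax
import jax.numpy as jnp

def show_what_the_solver_sees(x):
    # At trace time this prints something like
    # Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace ...>
    print("operator received:", type(x), x)
    return x  # identity "matvec", purely for illustration

jax.jit(show_what_the_solver_sees)(jnp.ones(4))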
There is some new functionality to allow this sort of non-JAX call using jax.pure_callback, but be aware that this requires a host sync, and so e.g. interoperating between JAX and PyTorch on a GPU will incur major performance penalties due to the movement of data to and from the host. I'm not sure if there's any example-driven documentation for pure_callback.
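As a rough, untested sketch of what that could look like for the reproducer above (reusing the A_t and B_j defined in the question; torch_matvec_host and A_op_cb are hypothetical names, and this assumes a JAX version that provides jax.pure_callback):

import numpy as onp
import jax
import torch

def torch_matvec_host(x_np):
    # Runs on the host with a concrete NumPy array, outside of JAX tracing.
    x_t = torch.from_numpy(onp.array(x_np)).to(A_t.device)
    return onp.asarray((A_t @ x_t).cpu().numpy())

def A_op_cb(x):
    # Declare the shape/dtype of the callback's result so JAX can trace the call.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(torch_matvec_host, out_spec, x)

x_from_cb, _ = jax.scipy.sparse.linalg.cg(A_op_cb, B_j)

Note that every solver iteration now round-trips the iterate device -> host -> device, which is exactly the performance caveat above, and JAX cannot differentiate through the callback.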
As for a suggested fix, my main suggestion would be to stick to using JAX implementations within JAX's solvers (or perhaps, depending on your goals, you could find similar functionality implemented entirely in PyTorch). Hope that helps!
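For concreteness, the stick-to-JAX route for this particular example could be as simple as the sketch below (reusing A_n, jaxdevice, and B_j from the question; it is essentially the A_op_n variant that already works in the reproducer, with the matrix placed on the JAX device up front):

import jax
import jax.numpy as jnp

# Materialize the operator as a JAX array once, outside the solver...
A_jax = jax.device_put(jnp.asarray(A_n), device=jaxdevice)

# ...and express the matvec with JAX ops only, so cg can trace it freely.
A_op_jax = lambda x: A_jax @ x

x_from_jax_op, _ = jax.scipy.sparse.linalg.cg(A_op_jax, B_j)

For a small dense system like this one, the all-PyTorch alternative would be a dense solve on the torch side (e.g. torch.linalg.solve) or a CG routine implemented in PyTorch.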