-
Hi everyone! I am trying to solve some inverse design problems with the help of Jax. The time needed to solve a sparse linear system with Jax is not well adapted to my case. I already have some functions from another Python library that reduce the solving time, but they are not in Jax... So, I need to plug my own solver into Jax. I tried naively with something like:

from jax.config import config
config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')

from jax import grad, jacfwd, jacrev, jacobian, random
import jax.numpy as jnp
import jax
from numpy.linalg import solve  # from special_package import super_solving_func

random_key = 42
key = random.PRNGKey(random_key)

M = 10
A = random.normal(key, shape=(M, M))
b = random.normal(key, shape=(M,))

@jax.custom_jvp
def custom_solve(Ad):
    x_res = solve(Ad, b)
    return jnp.asarray(x_res)

@custom_solve.defjvp
def custom_solve_jvp(primals, tangents):
    A, = primals
    A_dot, = tangents
    primal_out = custom_solve(A)
    tangent_out = jnp.asarray(solve(A, -A_dot @ primal_out))
    return primal_out, tangent_out

J = jax.jacrev(custom_solve)(A)

To make things easier in the beginning, I fixed b. But when I launch my script, the last line fails with an error. I understand Jax does not like that I call a Numpy function (solve) on its traced arrays. To go a bit deeper on my use case: I actually have an existing algorithm that generates my problem as a sparse matrix that I need to solve (actually it is a ...). Any help is welcome! Thanks in advance.
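As a quick check of the tangent rule used in custom_solve_jvp (a derivation sketch, with b held fixed as in the script above): differentiating the linear system gives

$$A x = b \;\Rightarrow\; \dot{A}\,x + A\,\dot{x} = 0 \;\Rightarrow\; \dot{x} = A^{-1}\!\left(-\dot{A}\,x\right),$$

which is exactly what solve(A, -A_dot @ primal_out) computes.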
-
The issue is that within the custom JVP definition, the arrays you receive are abstract Jax tracers rather than concrete numpy arrays, and numpy.linalg.solve cannot operate on tracers: its internal np.asarray call raises a TracerArrayConversionError. Any external (non-Jax) function you want to call under a transformation has to be routed through a callback mechanism instead of being called directly on traced values.
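A minimal sketch of that failure mode (an illustrative example, not code from the thread): any numpy routine that tries to convert a traced value to a concrete array fails the same way the custom JVP does.

import numpy as np
import jax
import jax.numpy as jnp

def f(v):
    # numpy.linalg.solve calls np.asarray on its inputs; under jax.jacrev the
    # argument v is a tracer, so the conversion raises TracerArrayConversionError.
    return np.linalg.solve(np.eye(2), v)

try:
    jax.jacrev(f)(jnp.ones(2))
except jax.errors.TracerArrayConversionError as err:
    print(type(err).__name__)  # TracerArrayConversionError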
-
So, there are ways you can do this currently by defining a custom primitive and using non-public routines which are not guaranteed to have a stable API. My hope is that we can make this cleaner and more intuitive in the future, but as a proof of concept here is some code that works in the most recent release.

Here is an example of defining a new primitive, np_sin, which computes the element-wise sine of an input using a callback to numpy.sin. In addition, I define a lowering rule (for compatibility with jit), a batching rule (for compatibility with vmap), and a jvp rule (for compatibility with grad and other autodiff transformations):

import numpy as np
import jax.numpy as jnp
import jax
from jax import core
from jax._src import dtypes
from jax.interpreters import ad, batching, mlir

def np_sin(x):
    return np_sin_p.bind(x)

np_sin_p = core.Primitive("np_sin")

# The abstract evaluation rule tells us the expected shape & dtype of the output.
def _np_sin_abstract_eval(x):
    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x.dtype))
    return core.ShapedArray(x.shape, dtype)

# The impl rule defines how the primitive is evaluated *outside* jit.
def _np_sin_impl(x):
    x = np.asarray(x)
    np_result = np.sin(x)
    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x.dtype))
    return jnp.asarray(np_result, dtype=dtype)

# The lowering rule defines how the primitive is evaluated *within* jit.
def _np_sin_lowering(ctx, x):
    # Note: the callback function must return a tuple of the expected shape & dtype.
    def sin_callback(x):
        out = np.sin(x)  # This is a numpy function, called on a numpy array
        return (out.astype(ctx.avals_out[0].dtype),)
    token = None  # Unused in this case
    result, token, keepalive = mlir.emit_python_callback(
        ctx, sin_callback, token, [x], ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result

np_sin_p.def_abstract_eval(_np_sin_abstract_eval)
np_sin_p.def_impl(_np_sin_impl)
mlir.register_lowering(np_sin_p, _np_sin_lowering)

# Since np.sin(x) is rank-polymorphic, the batching rule is easy.
batching.defvectorized(np_sin_p)

# jvp must be defined in terms of JAX primitives, so use cos(x) = sin(x + π/2)
ad.defjvp(np_sin_p, lambda g, x: g * np_sin(x + np.pi / 2))

x = jnp.arange(4)

# Evaluate via the impl rule
print(np_sin(x))
# [0. 0.84147096 0.9092974 0.14112 ]

# Evaluate under JIT via the lowering rule
print(jax.jit(np_sin)(x))
# [0. 0.84147096 0.9092974 0.14112 ]

# vmap uses the batching rule
print(jax.vmap(np_sin)(x))
# [0. 0.84147096 0.9092974 0.14112 ]

# grad uses the JVP rule
print(jax.grad(np_sin)(1.0))
# 0.5403023

All the JAX goodies working for a function implemented entirely in numpy! As I mentioned, these APIs are non-public, and in particular the jax._src and mlir internals used here (such as emit_python_callback) may change between releases.
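To connect this back to the original linear-solve question, here is a rough sketch of the same "call out to numpy" idea applied to an external solver. This is not from the reply above: it assumes a Jax version where the public jax.pure_callback API is available, uses numpy.linalg.solve as a stand-in for the external super_solving_func, and pairs it with custom_vjp so that reverse mode only ever needs forward and adjoint solves (no transpose rule for the callback).

import numpy as np
from numpy.linalg import solve  # stand-in for the external solver
import jax
import jax.numpy as jnp

def external_solve(A, b):
    # pure_callback ships A and b to the host as numpy arrays and promises an
    # output with the declared shape/dtype, so it can be called on traced values.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(lambda A, b: solve(np.asarray(A), np.asarray(b)), out_spec, A, b)

@jax.custom_vjp
def custom_solve(A, b):
    return external_solve(A, b)

def custom_solve_fwd(A, b):
    x = external_solve(A, b)
    return x, (A, x)  # save residuals for the backward pass

def custom_solve_bwd(res, g):
    A, x = res
    # Adjoint of x = A^{-1} b: lam = A^{-T} g, dA = -lam x^T, db = lam.
    lam = external_solve(A.T, g)
    return -jnp.outer(lam, x), lam

custom_solve.defvjp(custom_solve_fwd, custom_solve_bwd)

key = jax.random.PRNGKey(42)
A = jax.random.normal(key, (10, 10))
b = jax.random.normal(key, (10,))
loss = lambda A: jnp.sum(custom_solve(A, b) ** 2)
print(jax.grad(loss)(A).shape)  # (10, 10)

The custom_vjp route is a design choice: with only a jvp rule, reverse mode would also need a way to transpose the callback, whereas here the backward pass is just one extra adjoint call to the external solver. Depending on the Jax version, full Jacobians via jax.jacrev may additionally require specifying how pure_callback should behave under vmap, since jacrev vmaps the backward pass.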