The level is not particularly relevant here. The issue is that your function is being called inside JIT-compiled code, so you're attempting to convert a traced variable to a NumPy array, and that is not possible. You can reproduce the same error with something like this:

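```python
import jax
import numpy as np
import jax.numpy as jnp

@jax.jit
def call_numpy_sin(x):
  x_np = np.asarray(x)  # fails under jit: x is a tracer, not a concrete array
  out_np = np.sin(x_np)
  return jnp.asarray(out_np)

call_numpy_sin(jnp.arange(10))  # raises jax.errors.TracerArrayConversionError
```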
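To see that `jax.jit` is what triggers the failure, here is a small sketch that runs the same NumPy-based function both eagerly and under `jax.jit`. The name `numpy_sin` is illustrative, and the input is switched to floats so the eager result is easy to check. Eagerly, the argument is a concrete array and the conversion works; under `jax.jit` it is a tracer and JAX raises `jax.errors.TracerArrayConversionError`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_sin(x):
    # Converts its input to a NumPy array, so it needs a concrete value.
    return jnp.asarray(np.sin(np.asarray(x)))


x = jnp.arange(10.0)

# Called eagerly, x is a concrete jax.Array and the conversion succeeds.
eager_out = numpy_sin(x)
print(eager_out)

# Called under jax.jit, x is a tracer with no concrete value, so the
# NumPy conversion raises jax.errors.TracerArrayConversionError.
try:
    jax.jit(numpy_sin)(x)
    jit_error = None
except jax.errors.TracerArrayConversionError as err:
    jit_error = err
print(type(jit_error).__name__)
```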

To understand what's going on here, you might read through How To Think In JAX, in particular the JIT Mechanics section.
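The tracing behaviour described there can be observed directly. In this sketch (the function `double` and the `trace_count` counter are hypothetical, for illustration only), the Python body of a jitted function runs only while JAX traces it. During tracing, `x` carries only a shape and dtype, and later calls with the same shape and dtype reuse the compiled version without running the Python code again.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def double(x):
    global trace_count
    # This Python body runs only while JAX traces the function; x is an
    # abstract tracer that knows its shape and dtype, not its values.
    trace_count += 1
    print("tracing with", x.shape, x.dtype)
    return 2 * x


double(jnp.ones(10))   # first call: traces and compiles
double(jnp.zeros(10))  # same shape/dtype: reuses the cached compilation
double(jnp.ones(5))    # new shape: traces again
print(trace_count)
```

This is also why NumPy can't be handed the traced value: at trace time there is no concrete data for it to work on.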

The short version: you can't convert a traced variable to a NumPy array, because under `jit` it has only a shape and dtype, not concrete values.
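Two common ways around this, sketched under the assumption that the goal is simply to compute a sine inside jitted code (the function names below are hypothetical). The first replaces `np.sin` with `jnp.sin`, which JAX can trace. The second uses `jax.pure_callback`, which runs arbitrary host-side NumPy code on concrete arrays. That route requires declaring the output shape and dtype up front, and it is typically slower, so prefer the first option when the NumPy code can be rewritten.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Option 1: use jax.numpy instead of numpy, so the op stays traceable.
@jax.jit
def jnp_sin(x):
    return jnp.sin(x)


# Option 2 (sketch): keep NumPy code but run it on the host via
# jax.pure_callback, which passes concrete arrays to the callback.
def _host_sin(x_np):
    x_np = np.asarray(x_np)
    # Cast back to the input dtype so it matches the declared output spec.
    return np.sin(x_np).astype(x_np.dtype)


@jax.jit
def callback_sin(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_sin, out_spec, x)


x = jnp.arange(10.0)
print(jnp_sin(x))
print(callback_sin(x))
```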

I don't know of a…

Answer selected by jmsull