TypeError: Argument 'cpu:0' of type <class 'jaxlib.xla_extension.CpuDevice'> is not a valid JAX type. #5821
Replies: 12 comments 1 reply
-
When you call …
```
desc = jax.device_put(descriptors, jax.devices()[0])
matches_jit = jit(match_faces3)
%timeit matches_jit(desc)
```
Also, as a side note: JAX's runtime model is asynchronous, so if you're timing operations you should use
```
%timeit matches_jit(desc).block_until_ready()
```
See Asynchronous Dispatch for more information.
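A rough sketch of the same pattern as a plain script, since `match_faces3` itself isn't shown here (`match_faces_sketch` is a hypothetical stand-in). The input is placed on a device with `jax.device_put` before the call, so the jitted function only ever receives arrays. JAX can trace only array-like arguments, and passing a `Device` object itself as an argument to a jitted function is one likely way to get a `... is not a valid JAX type` error like the one in the title. The first call is excluded from the timing because it triggers compilation, and `block_until_ready()` waits for the asynchronously dispatched result:

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for match_faces3: it takes only arrays as arguments.
def match_faces_sketch(desc):
    return desc @ desc.T


matches_jit = jax.jit(match_faces_sketch)

# Place the input on a device *before* calling the jitted function,
# rather than passing the Device object itself as an argument.
desc = jax.device_put(jnp.ones((512, 128), dtype=jnp.float32), jax.devices()[0])

# The first call compiles the function; keep it out of the measurement.
matches_jit(desc).block_until_ready()

start = time.perf_counter()
result = matches_jit(desc).block_until_ready()  # wait for the async computation
elapsed = time.perf_counter() - start
print(f"{elapsed:.6f} s")
```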
-
I tried this previously but got this exception, so I tried it that way:
```
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[512])>with<DynamicJaxprTrace(level=0/1)>.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using …
```
-
It is hard to say for sure because the example you gave is incomplete, but I suspect …
-
I changed cos into a JAX function but got the same error. Actually, this is the original code, which I ran using NumPy; I'm using JAX to reduce the time taken:
```python
distances = np.empty((len(descriptors), len(database)))
f.write("Descriptors")
f.write(str(descriptors))
time1 = time.time()
for i, desc in enumerate(descriptors):
    for j, identity in enumerate(database):
        dist = []
        for k, id_desc in enumerate(identity[1]):
            dist.append(cosine_dist(desc, id_desc))
        # smallest distance between descriptor i and any descriptor of identity j
        distances[i][j] = dist[np.argmin(dist)]
time2 = time.time() - time1
print("time2", time2)
f.write("Distances")
f.write(str(distances))
```
-
Thanks for the further information. If you want further help debugging this, I'd suggest putting together a minimal reproducible example and including it here: that will take the guesswork out of helping find the cause of your error.
-
The database starts like `[('8174', …`, and this is what the descriptors look like:
```python
descriptors = np.array([[-1.69, 4.4, 4.27, 1.96, 2.7, -5.73, -5.41, 1.12, 2.5, 2.09, 5.8, -8.7, 6.7, 4.1, 7.4, 6.3, 9.7, 2.4, 6.4, 3.3]])
```
The `cosine_func` is a normal cos function.
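Vectorized approaches work best when the data is a single array rather than a Python list of tuples. A minimal sketch, assuming (hypothetically) that `database` is a list of `(name, descriptors)` tuples in which every identity stores the same number `K` of `D`-dimensional descriptors, so the per-identity arrays can be stacked into one `(J, K, D)` array. The toy data below is random and only illustrates the assumed layout:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
D = 20  # descriptor length, matching the 20-value example above
K = 3   # hypothetical number of stored descriptors per identity

# Hypothetical toy data in the assumed format: one (name, (K, D) array) tuple per identity.
database = [(f"id{n}", rng.normal(size=(K, D))) for n in range(4)]
descriptors = rng.normal(size=(2, D))

names = [name for name, _ in database]
# Stack the per-identity descriptor arrays into one (J, K, D) array; requires equal K.
db_stack = jnp.asarray(np.stack([id_descs for _, id_descs in database]))
desc_arr = jnp.asarray(descriptors)  # (N, D)

print(names)
print(db_stack.shape, desc_arr.shape)
```

If identities have different numbers of descriptors, the arrays cannot be stacked directly; padding plus a mask is one option, not shown here.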
-
Thanks. Rather than describing your code in words (what does "normal cos function" mean?), it would be more helpful if you could include a complete code snippet, i.e. something that someone could copy and paste into a runtime and see the same results you are seeing. The minimal reproducible example link from my earlier comment is a helpful resource with more details on how to construct such an example.
-
```python
import numpy as np
import time, datetime
from scipy.spatial.distance import cosine
import jax
import jax.numpy as jnp

f = open("write_out.txt", 'w+')

def cosine_dist(x, y):
    return cosine(x, y) * 0.5

# `descriptors` and `database` are loaded elsewhere (not shown).
distances = np.empty((len(descriptors), len(database)))
f.write("Descriptors")
f.write(str(descriptors))
time1 = time.time()
for i, desc in enumerate(descriptors):
    for j, identity in enumerate(database):
        dist = []
        for k, id_desc in enumerate(identity[1]):
            dist.append(cosine_dist(desc, id_desc))
        distances[i][j] = dist[np.argmin(dist)]
time2 = time.time() - time1
print("time2", time2)
f.write("Distances")
f.write(str(distances))
```
The JAX version I tried to implement is:
```python
distances = np.empty((len(desc), len(database)))
f.write("Descriptors")
f.write(str(desc))
time1 = time.time()
for i, descr in enumerate(desc):
    for j, identity in enumerate(database):
        dist = []
        for k, id_desc in enumerate(identity[1]):
            dist.append(cosine_dist(descr, id_desc))
        distances[i][j] = dist[jnp.argmin(jnp.asarray(dist))]
time2 = time.time() - time1
print("time2", time2)
distances = distances.tolist()
f.write(str(distances))
```
When I made just this change, the JAX version took longer, so I tried putting it in a jit function.
-
Thanks! This is much more clear. The issue is that `scipy.spatial.distance.cosine` is a NumPy/SciPy function, so it cannot operate on the traced values that `jit` passes in:
```python
from scipy.spatial.distance import cosine
from jax import jit
import jax.numpy as jnp

jit(cosine)(jnp.arange(4), jnp.arange(4))
```
```
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[4])>with<DynamicJaxprTrace(level=0/1)>.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
```
If you want to JIT-compile this computation, you'll have to implement the cosine distance using JAX functions; for example:
```python
import numpy as np
import jax.numpy as jnp
from jax import jit
from scipy.spatial.distance import cosine

@jit
def jax_cosine(u, v):
    return 1 - jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))

u = np.random.rand(100)
v = np.random.rand(100)
print(np.allclose(cosine(u, v), jax_cosine(u, v)))
# True
```
Additionally, if you want your implementation to have good performance, you should avoid iterating over array axes in favor of vectorized operations via broadcasting or through tools like `vmap`.
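As a sketch of the broadcasting approach, assuming the database has already been stacked into a hypothetical `(J, K, D)` array (every identity holding the same number `K` of descriptors), the whole triple loop collapses into a few array operations. It keeps the `0.5 *` scaling from the `cosine_dist` helper posted earlier and checks the result against an explicit loop on small random data:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.spatial.distance import cosine


@jax.jit
def min_cosine_dist(desc, db):
    """desc: (N, D) queries; db: (J, K, D) stacked identities -> (N, J) distances."""
    # Normalize every descriptor to unit length along the last axis.
    desc_n = desc / jnp.linalg.norm(desc, axis=-1, keepdims=True)
    db_n = db / jnp.linalg.norm(db, axis=-1, keepdims=True)
    # Broadcast (N, 1, 1, D) against (1, J, K, D), sum over D: cosine similarity, (N, J, K).
    sim = jnp.sum(desc_n[:, None, None, :] * db_n[None, :, :, :], axis=-1)
    # Same 0.5 * cosine-distance scaling as cosine_dist, minimized over K.
    return jnp.min(0.5 * (1.0 - sim), axis=-1)


rng = np.random.default_rng(0)
desc = rng.normal(size=(5, 8))
db = rng.normal(size=(3, 4, 8))

# Reference result from the explicit triple loop, using SciPy's cosine.
expected = np.empty((5, 3))
for i in range(5):
    for j in range(3):
        expected[i, j] = min(0.5 * cosine(desc[i], db[j, k]) for k in range(4))

result = np.asarray(min_cosine_dist(desc, db))
print(np.allclose(result, expected, atol=1e-5))
```

The broadcast product materializes an `N * J * K * D` intermediate; for large databases, `jnp.einsum("nd,jkd->njk", desc_n, db_n)` computes the same similarities as a contraction instead.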
-
Thank you @jakevdp. I now get this error:
```
in match_faces3(desc)
TypeError: list indices must be integers or slices, not DynamicJaxprTracer
```
-
That makes sense: you're trying to iterate over a traced array. As I mentioned, you should avoid explicit iteration in favor of either vectorized operations or `vmap`.
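The `vmap` route can mirror the three nested loops directly: write the innermost computation for a single pair, then let `vmap` add each loop axis. A minimal sketch, again assuming a hypothetical stacked database of shape `(J, K, D)`; `min_dist_one`, `over_identities`, and `all_distances` are illustrative names, and `jax_cosine` repeats the formula posted earlier in the thread:

```python
import numpy as np
import jax
import jax.numpy as jnp


def jax_cosine(u, v):
    # Same formula as jax_cosine earlier in the thread: 1 - u.v / (|u| |v|).
    return 1 - jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))


def min_dist_one(desc_i, identity_descs):
    # Replaces the innermost loop over k: (K, D) -> K distances -> their minimum.
    dists = jax.vmap(lambda v: 0.5 * jax_cosine(desc_i, v))(identity_descs)
    return jnp.min(dists)


# vmap over identities j (axis 0 of db), then over query descriptors i (axis 0 of desc).
over_identities = jax.vmap(min_dist_one, in_axes=(None, 0))
all_distances = jax.jit(jax.vmap(over_identities, in_axes=(0, None)))

rng = np.random.default_rng(1)
desc = rng.normal(size=(5, 8)).astype(np.float32)   # (N, D)
db = rng.normal(size=(3, 4, 8)).astype(np.float32)  # (J, K, D), hypothetical stacked database
out = all_distances(desc, db)
print(out.shape)  # (5, 3)
```

Compared with the broadcasting version, this keeps the per-pair logic readable as scalar code and leaves the batching to `vmap`.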
-
Actually, looking more closely, the issue is that …
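Judging from the JAX attempt posted earlier, one likely source of the `list indices must be integers or slices, not DynamicJaxprTracer` error is the line `distances[i][j] = dist[jnp.argmin(jnp.asarray(dist))]`: under `jit`, `jnp.argmin(...)` returns a tracer, and a plain Python list cannot be indexed by a tracer. Storing a traced value into the NumPy array `distances` would fail for a similar reason. A minimal sketch of two ways to keep the reduction inside JAX (the function names are hypothetical):

```python
import jax
import jax.numpy as jnp


@jax.jit
def min_via_argmin(dists):
    # dists: a Python list of scalars, like the `dist` list built in the inner loop.
    stacked = jnp.stack(dists)  # one JAX array instead of a Python list
    # Indexing a JAX array with a traced index works; indexing the list itself would not.
    return stacked[jnp.argmin(stacked)]


@jax.jit
def min_direct(dists):
    # Simpler equivalent: the value at the argmin is just the minimum.
    return jnp.min(jnp.stack(dists))


vals = [0.3, 0.1, 0.7]
print(min_via_argmin(vals), min_direct(vals))
```

In practice the vectorized or `vmap` formulations avoid building the Python list at all, which also removes the per-element loop overhead.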
-
TypeError: Argument 'cpu:0' of type <class 'jaxlib.xla_extension.CpuDevice'> is not a valid JAX type.
I encountered this error while trying to run the code below.
`desc` is a multidimensional array.
I tried everything but couldn't figure out the problem. Please help, @mattjj.
Originally posted by @Joy-Preetha in #4416 (comment)