I think this is a long shot, but: is there any way to convert a numpy function into a jax.numpy function? Just doing the usual trick of changing
by
is not satisfactory because it requires me to modify external libraries, and furthermore, sometimes I get errors such as, in
Thanks!
No, unfortunately there's no way to do this in general. JAX cannot implement the entire numpy API. As a simple example, numpy will happily create and manipulate arrays of strings, but XLA (JAX's backend) has no string type, so the corresponding operations cannot be expressed in JAX.
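A small sketch of the string example, assuming a recent JAX install. The helper `jax_accepts` is hypothetical, written only for illustration; catching both `TypeError` and `ValueError` is a hedge, since the exact exception JAX raises for an unsupported dtype may differ between versions.

```python
import numpy as np
import jax.numpy as jnp

# numpy builds and manipulates string arrays without complaint.
words = np.array(["jax", "numpy", "xla"])
shouted = np.char.upper(words)  # elementwise string operation in numpy
print(shouted)

def jax_accepts(arr: np.ndarray) -> bool:
    """Hypothetical helper: True if jax.numpy can convert `arr` to a JAX array."""
    try:
        jnp.asarray(arr)
        return True
    except (TypeError, ValueError):
        # XLA has no string dtype, so JAX refuses the conversion.
        return False

print(jax_accepts(np.array([1.0, 2.0])))  # numeric data converts fine
print(jax_accepts(words))                 # string data is rejected
```

Any numpy function that builds or returns string arrays internally therefore has no jax.numpy equivalent, no matter how the import is swapped.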
You should think of `jax.numpy` as a numpy-like API, but not as a 100% faithful drop-in replacement for numpy.
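One more concrete difference that often breaks numpy code run under `jax.numpy`: JAX arrays are immutable, so in-place item assignment fails, and JAX instead offers the functional `x.at[idx].set(value)` update. The function names below are hypothetical, chosen only for this sketch.

```python
import numpy as np
import jax.numpy as jnp

def set_first_inplace(x):
    """numpy idiom: mutate the array in place."""
    x[0] = 10.0
    return x

def set_first_functional(x):
    """JAX idiom: `.at[...].set(...)` returns an updated copy."""
    return x.at[0].set(10.0)

def supports_inplace(x) -> bool:
    """Hypothetical check: does numpy-style item assignment work on `x`?"""
    try:
        set_first_inplace(x)
        return True
    except TypeError:
        # JAX arrays are immutable, so item assignment raises TypeError.
        return False

print(supports_inplace(np.zeros(3)))
print(supports_inplace(jnp.zeros(3)))

b = jnp.zeros(3)
c = set_first_functional(b)
print(b, c)  # b is unchanged; c has 10.0 in position 0
```

So even numerically valid numpy code may need to be rewritten, not just re-imported, to run under JAX.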