DeviceArray as static argument raises non-hashable error #13086
yiminghwang asked this question in Q&A · answered by mattjj
Hi there, I'm passing a DeviceArray to a jitted function as a static argument. The simplified version below raises a non-hashable error:

```python
from functools import partial

from jax import jit
import jax.numpy as jnp


@partial(jit, static_argnums=(1,))  # marks `v` (argument index 1) as static
def f(x, v):
    return x @ v


def main():
    A = jnp.ones((5, 5))
    x = jnp.array([1, 2, 3, 4, 5])
    y = f(A, x)  # the error is raised on this call
    print("end")


if __name__ == '__main__':
    main()
```
Does anyone know how to fix it?
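A quick way to see where the error comes from: `jax.jit` keys its compilation cache on the values of static arguments, so every static argument must be hashable, and JAX arrays (like NumPy arrays) are not. The sketch below assumes a recent JAX release, where arrays are `jax.Array` (older releases used the `DeviceArray` type named in the question). It only checks that hashing fails and that the static call fails; the exact exception type and message may differ between JAX versions, so both `TypeError` and `ValueError` are caught.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])

# JAX arrays, like NumPy arrays, do not support hash().
try:
    hash(x)
    array_is_hashable = True
except TypeError:
    array_is_hashable = False


# jit must hash static arguments to look up cached compilations,
# so passing an array in a static position fails when f is called.
@partial(jax.jit, static_argnums=(1,))
def f(a, v):
    return a @ v


try:
    f(jnp.ones((5, 5)), x)
    static_call_failed = False
except (TypeError, ValueError) as err:
    static_call_failed = True
    error_message = str(err)  # wording varies by JAX version
```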
Accepted answer from mattjj (Nov 3, 2022):
Thanks for the question! Take a look at this comment. The thread (especially the comments just above it) explains why this error is raised.
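Two common ways around the error are sketched below. This is a minimal illustration, not necessarily what the linked comment recommends, and the names `f_traced` and `f_static` are hypothetical. In most cases the array does not need to be static at all, so simply dropping `static_argnums` lets `jit` trace `v` like any other input. If the value really must be known at compile time (for example, because it drives Python control flow or array shapes), one option is to pass a hashable tuple and rebuild the array inside the function. Note that every distinct tuple value then triggers a fresh compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def f_traced(x, v):
    # v is an ordinary traced argument, so it never needs to be hashed.
    return x @ v


@partial(jax.jit, static_argnums=(1,))
def f_static(x, v):
    # v arrives as a hashable tuple of Python numbers; convert it back
    # to an array (matching x's dtype) inside the traced function.
    return x @ jnp.asarray(v, dtype=x.dtype)


A = jnp.ones((5, 5))
x = jnp.array([1, 2, 3, 4, 5])

y1 = f_traced(A, x)
y2 = f_static(A, tuple(x.tolist()))  # tuple of ints is hashable
```

Each row of `A` is all ones, so both calls compute the row sum of `x` and return `[15., 15., 15., 15., 15.]`.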