jax.numpy.linalg.eig #8813
Unanswered
MarcinPlodzien asked this question in Q&A
Replies: 2 comments
Try
This seems to work fine.
There are two things going on here:
Hello,
I have encountered a problem with a simple eig call on the GPU, using the following code:
import jax.numpy as np
A = np.array([[1.,2.],[2.,3.]])
D, V = np.linalg.eig(A)
I obtain the following error:
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/_src/numpy/linalg.py", line 291, in eig
return lax_linalg.eig(a, compute_left_eigenvectors=False)
JaxStackTraceBeforeTransformation: TypeError: eig_translation_rule() takes 2 positional arguments but 4 positional arguments (and 2 keyword-only arguments) were given
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/marcin/workspace/cupy_test/cupy_test.py", line 52, in <module>
D, V = np.linalg.eig(A)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/_src/numpy/linalg.py", line 291, in eig
return lax_linalg.eig(a, compute_left_eigenvectors=False)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 85, in eig
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/_src/lax/linalg.py", line 385, in eig_impl
xla.apply_primitive(eig_p, operand,
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/interpreters/xla.py", line 416, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/_src/util.py", line 187, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/_src/util.py", line 180, in cached
return f(*args, **kwargs)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/interpreters/xla.py", line 439, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/interpreters/xla.py", line 829, in lower_xla_callable
out_nodes = jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
File "/home/marcin/anaconda3/lib/python3.9/site-packages/jax/interpreters/xla.py", line 577, in jaxpr_subcomp
ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
TypeError: eig_translation_rule() takes 2 positional arguments but 4 positional arguments (and 2 keyword-only arguments) were given
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Can anyone tell me what might be wrong?
Note that JAX's SVD works fine.
Thanks for any help and suggestions!
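Since the matrix A in the question is real and symmetric, one possible workaround (an assumption on my part, not something confirmed in this thread) is jnp.linalg.eigh, the symmetric/Hermitian eigensolver, which, as far as I know, had a GPU implementation in JAX releases of that period, unlike the general eig. A minimal sketch that also checks the result:

```python
import jax.numpy as jnp

# The matrix from the question; it is real and symmetric.
A = jnp.array([[1., 2.], [2., 3.]])

# eigh assumes a symmetric (Hermitian) input and returns real eigenvalues
# in ascending order, with the eigenvectors as the columns of V.
w, V = jnp.linalg.eigh(A)

# Sanity check: each column v satisfies A @ v = lambda * v,
# i.e. A @ V equals V scaled column-wise by w.
residual = float(jnp.max(jnp.abs(A @ V - V * w)))

print(w)         # approximately [-0.236, 4.236], i.e. 2 -/+ sqrt(5)
print(residual)  # close to zero (float32 round-off)
```

This only applies to symmetric input; for a general (nonsymmetric) matrix, eigh silently uses only one triangle of A and gives wrong results.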
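If a general (nonsymmetric) eigendecomposition is really needed, another possible workaround is to run eig on the CPU backend. This assumes, as I understand it, that in JAX releases from around the time of this question the nonsymmetric eig kernel existed only for the CPU backend. The sketch below commits the input to a CPU device with jax.device_put and jax.devices("cpu"); it is not taken from the replies above.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1., 2.], [2., 3.]])

# Commit the input to the first CPU device. Computations on committed
# arrays run on that device, so eig uses the CPU kernel even if a GPU
# is the default backend.
cpu = jax.devices("cpu")[0]
A_cpu = jax.device_put(A, cpu)

# eig always returns complex outputs (complex64 for float32 input);
# the order of the eigenvalues is not guaranteed.
D, V = jnp.linalg.eig(A_cpu)

# Sanity check: A @ V should equal V scaled column-wise by D.
residual = float(jnp.max(jnp.abs(A_cpu @ V - V * D)))

print(jnp.sort(D.real))
print(residual)
```

The result stays on the CPU, so it has to be moved back (e.g. with jax.device_put to a GPU device) if later computations should run on the GPU.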