Problem in porting scipy.linalg.orth with JIT #8042
-
A straightforward port of `scipy.linalg.orth`:

```python
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.linalg import svd

def orth(A, rcond=None):
    u, s, vh = svd(A, full_matrices=False)
    M, N = u.shape[0], vh.shape[1]
    if rcond is None:
        rcond = jnp.finfo(s.dtype).eps * max(M, N)
    tol = jnp.amax(s) * rcond
    # numerical rank: how many singular values exceed the tolerance
    num = jnp.sum(s > tol, dtype=int)
    Q = u[:, :num]
    # Q = lax.dynamic_slice(u, (0, 0), (M, num))
    return Q

# The following doesn't work: under jit, `num` is a tracer,
# so `u[:, :num]` has no static shape.
orth_jit = jit(orth, static_argnums=(1,))
```

The key problem is that the number of columns in `Q` is data-dependent, and JAX doesn't support dynamically sized slices. Even `lax.dynamic_slice` (the commented-out line) doesn't help, because its slice sizes must be static. Is this the reason why `orth` is not included in `jax.scipy.linalg`? One possible workaround I was thinking of was to change the return type to a tuple.
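To make the `lax.dynamic_slice` point concrete, here is a small sketch (the function names `window` and `prefix` are hypothetical). Under `jit`, the *start* index of `lax.dynamic_slice` may be a traced value, but the *slice size* must be a concrete Python integer known at trace time, which is exactly what `num` is not.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

@jax.jit
def window(x, start):
    # `start` is traced; the size (3,) is a static Python int, so this works.
    return lax.dynamic_slice(x, (start,), (3,))

@jax.jit
def prefix(x, n):
    # `n` is traced, so it cannot serve as a slice size: tracing fails.
    return lax.dynamic_slice(x, (0,), (n,))

print(window(x, 4))

try:
    prefix(x, 3)
except Exception as e:  # JAX rejects the non-concrete slice size
    print("prefix failed:", type(e).__name__)
```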
-
Yes, in general JAX does not (yet) support data-dependent shapes inside a `jit`. Here, `num` necessarily depends on the matrix. Outside a `jit` it works fine.

The usual workaround in JAX is to return a padded array, i.e., return `Q` unsliced together with `num`, which tells you how much of the array contains useful data. This requires an API change; sometimes we do that by adding an optional keyword argument that states whether you want to exactly mimic the NumPy/SciPy behavior or return the `(Q, num)` pair in a `jit`-compatible way.

We have plans for more first-class support for dynamic shapes inside `jit`, but nothing is ready at this time.
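A minimal sketch of the padded-array workaround, assuming a hypothetical keyword-only argument `padded` (not an existing JAX API). With `padded=True` it keeps all `min(M, N)` columns of `u` and zeroes the ones past the numerical rank with `jnp.where`, so the output shape depends only on `A.shape`; zero-filling is one convenient choice, and leaving those columns untouched would work just as well as long as callers respect `num`. With the default `padded=False` it mimics SciPy's data-dependent shape, which only works eagerly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import svd


def orth(A, rcond=None, *, padded=False):
    """Orthonormal basis for the range of A.

    padded=False: SciPy-style result with data-dependent shape (eager only).
    padded=True: (Q, num) with a static shape, usable under jax.jit.
    """
    u, s, vh = svd(A, full_matrices=False)
    M, N = A.shape  # static, even under jit
    if rcond is None:
        rcond = jnp.finfo(s.dtype).eps * max(M, N)
    tol = jnp.max(s) * rcond
    num = jnp.sum(s > tol)  # numerical rank; depends on the data
    if padded:
        # Zero the columns past `num` instead of slicing them away.
        keep = jnp.arange(u.shape[1]) < num
        return jnp.where(keep[None, :], u, 0), num
    # Needs a concrete `num`, so this branch fails under jit.
    return u[:, : int(num)]


# `padded` picks a Python code path, so it must be static.
orth_jit = jax.jit(orth, static_argnames=("padded",))

A = jnp.array([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]])  # rank 1
Q, num = orth_jit(A, padded=True)  # Q.shape == (3, 2), num == 1
Q_exact = orth(A)                  # Q_exact.shape == (3, 1)
```

Outside `jit`, the SciPy-shaped result can also be recovered from the padded one as `Q[:, :int(num)]`.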