Hi, my problem is that, during an optimization, I need to truncate the singular values at different indices. This post has been edited to include the following minimal reproducible example:

    import jax
    from functools import partial
    from jax import jit
    import jax.numpy as jnp
    import jax.numpy.linalg as la
    import jaxopt as opt

    @jit
    def compute_idx(S, *, percent=.90):
        cum_sum = jnp.cumsum(S)  # running total of the values
        total_sum = cum_sum[-1]  # last entry is the overall total
        # index of where to insert value
        idx = jnp.searchsorted(cum_sum, total_sum * percent) + 1
        return idx

    @partial(jit, static_argnums=(1,))
    def lraTruncated(X, idx):
        U, S, Vh = la.svd(X, full_matrices=False, compute_uv=True)
        S_diag = jnp.diag(S)  # create diagonal matrix
        S_diag_trunc = jax.lax.dynamic_slice(S_diag, (0, 0), (idx, idx))  # truncate values
        U_trunc = jax.lax.dynamic_slice(U, (0, 0), (U.shape[0], idx))
        Vh_trunc = jax.lax.dynamic_slice(Vh, (0, 0), (idx, Vh.shape[1]))
        return U_trunc @ S_diag_trunc @ Vh_trunc  # recombine

    # create a target matrix and a zero initialization matrix
    key = jax.random.PRNGKey(42)
    matrix = jax.random.normal(key, (100, 300))
    init = jnp.zeros(matrix.shape)

    # example objective function
    def objective(X, A):
        idx = compute_idx(A)
        approx = lraTruncated(A, idx)
        return la.norm(approx - X, ord='fro')

    # lambda of objective function
    obj = lambda x: objective(matrix, x)

    # optimization loop starts
    optimizer = opt.BFGS(fun=obj, value_and_grad=False, stepsize=1e-3, maxiter=100_000)
    # Initialize the optimizer state
    state = optimizer.init_state(init_params=init)
    # Run optimization
    result = optimizer.run(init_params=init)

The first function computes the truncation index and the second performs the truncated SVD.

Any idea why this fails? How can I fix it?
Replies: 1 comment 4 replies
You cannot pass an array as an argument that is marked static. Have you tried passing a Python integer instead? That is, instead of this:

    result = lraTruncated(x, indices[0])  # if indices is a JAX Array, then indices[0] is a scalar-shaped JAX array

you could write something like this:

    result = lraTruncated(x, 3)

If that doesn't help, then I suspect we'll need more information on how exactly you're using your function. Consider editing your question to add a minimal reproducible example.
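To illustrate the static-argument point, here is a small sketch. The helper `first_columns` and the array `x` are hypothetical and not part of the original code. It only shows the cases that work: a literal Python integer, and a concrete JAX scalar converted with `int()` outside of any traced code.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: k is marked static, so it may determine the output shape.
@partial(jax.jit, static_argnums=(1,))
def first_columns(x, k):
    return x[:, :k]


x = jnp.arange(10.0).reshape(2, 5)

# A plain Python integer is a valid static argument.
a = first_columns(x, 3)

# A concrete JAX scalar can be converted to a Python int first,
# but only outside of jit/grad tracing, where its value is known.
k = int(jnp.array(2))
b = first_columns(x, k)
```

The `int()` conversion is not available for a value that is being traced, which is the situation inside an optimizer step.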
You cannot run the code as written in your example. `idx` is a dynamic value within the optimizer, and it cannot be converted to a static value, nor can an array of dynamic shape be created (see JAX sharp bits: dynamic shapes).

To proceed, you'll have to figure out how to create an objective function that relies only on statically-shaped arrays. For example, you could do this by zeroing out the singular values beyond the desired index, while keeping the arrays their original shape:
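One way this could look is sketched below. It is a minimal illustration under stated assumptions, not necessarily the exact code from the reply: the function name `lra_masked` is hypothetical, and the truncation index is computed from the singular values with the same 90% cumulative-sum rule as `compute_idx` in the question.

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as la


@jax.jit
def lra_masked(X, percent=0.90):
    # Hypothetical replacement for lraTruncated: every array keeps a static shape.
    U, S, Vh = la.svd(X, full_matrices=False)
    cum_sum = jnp.cumsum(S)
    # Dynamic index: number of singular values to keep.
    idx = jnp.searchsorted(cum_sum, cum_sum[-1] * percent) + 1
    # Boolean mask of static shape; True for the first idx singular values.
    keep = jnp.arange(S.shape[0]) < idx
    S_masked = jnp.where(keep, S, 0.0)
    # Scaling the columns of U by S_masked is equivalent to U @ diag(S_masked).
    return (U * S_masked) @ Vh


key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (6, 10))
approx = lra_masked(A)  # same shape as A, regardless of the value of idx
```

Because `idx` is only used in a comparison and never as a shape, it can remain a traced value, so no argument has to be marked static.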