
You cannot run the code as written in your example. Inside the optimizer, idx is a traced (dynamic) value: it cannot be converted to a static Python integer, and JAX cannot create an array whose shape depends on it (see JAX Sharp Bits: Dynamic Shapes).
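A minimal sketch of the failure mode, assuming the original code sliced the singular values with idx (the function names below are hypothetical, not from the original question). Under jax.jit every argument is traced, so S[:idx] would need an output length known only at run time. Marking idx as static with static_argnums works, but only when idx is a plain Python int at trace time, which is not the case when the optimizer itself produces it.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def truncate_dynamic(S, idx):
    # idx is traced here, so the length of S[:idx] is not known at compile time.
    return S[:idx]


@partial(jax.jit, static_argnums=1)
def truncate_static(S, idx):
    # idx is a compile-time constant here; JAX recompiles for each new value.
    return S[:idx]


def try_dynamic_slice():
    S = jnp.array([3.0, 2.0, 1.0])
    try:
        truncate_dynamic(S, 2)
    except (TypeError, IndexError) as err:
        # JAX rejects the non-static slice bound; the exact exception type and
        # message vary between JAX versions.
        return type(err).__name__
    return None


if __name__ == "__main__":
    print(try_dynamic_slice())
    print(truncate_static(jnp.array([3.0, 2.0, 1.0]), 2))
```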

To proceed, you'll need an objective function that relies only on statically shaped arrays. One way to do this is to zero out the singular values beyond the desired index while keeping every array at its original shape:

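import jax.numpy as jnp
from jax import jit
from jax.numpy import linalg as la

@jit
def lraTruncated(X, idx):
    U, S, Vh = la.svd(X, full_matrices=False, compute_uv=True)
    # Keep the first idx singular values and zero the rest; the shape of S never changes.
    S = jnp.where(jnp.arange(len(S)) < idx, S, 0)
    return U @ jnp.diag(S) @ Vh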
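As a usage sketch (an illustration, not part of the original answer), the masked version accepts idx as a traced array, so it can even be vmapped over several ranks in one call. The function is renamed lra_truncated here and uses (U * S) @ Vh, which scales the columns of U and is equivalent to U @ jnp.diag(S) @ Vh without building the diagonal matrix. The random 6x4 test matrix is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp
from jax.numpy import linalg as la


@jax.jit
def lra_truncated(X, idx):
    U, S, Vh = la.svd(X, full_matrices=False)
    # Mask instead of slice: shapes stay fixed, only the values depend on idx.
    S = jnp.where(jnp.arange(S.shape[0]) < idx, S, 0.0)
    # Broadcasting U * S scales column j of U by S[j], same as U @ jnp.diag(S).
    return (U * S) @ Vh


key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (6, 4))

# idx is a traced value here, so one compiled function handles every rank.
ranks = jnp.arange(1, 5)
approx = jax.vmap(lra_truncated, in_axes=(None, 0))(X, ranks)
print(approx.shape)  # (4, 6, 4)
```

Because the rank only enters through a mask, the output shape is the same for every idx; the last slice (rank 4) reproduces X up to floating-point error.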

Answer selected by skyetomez