
Hi, thanks for the question. Algorithms that involve dynamically shaped arrays sometimes take some extra thought in JAX, but they can often still be implemented efficiently.
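One common way to handle a data-dependent slice such as `signal[start:end]` under `jax.jit` is to keep the array at its full, static shape and mask out the elements you don't want. The sketch below assumes, purely for illustration, that the per-segment quantity is a plain sum. The helper names `segment_sum` and `all_segment_sums` are hypothetical, and masking is just one possible technique, not necessarily what your actual computation needs.

```python
import jax
import jax.numpy as jnp


def segment_sum(signal, start, end):
    # Hypothetical helper: sum of signal[start:end].
    # Under jax.jit, start/end are traced, so signal[start:end] would have a
    # data-dependent shape. Instead, keep the full fixed-shape array and zero
    # out every element outside the half-open interval [start, end).
    idx = jnp.arange(signal.shape[0])
    mask = (idx >= start) & (idx < end)
    return jnp.sum(jnp.where(mask, signal, 0.0))


# Works with traced start/end because all shapes are static.
segment_sum_jit = jax.jit(segment_sum)


def all_segment_sums(signal):
    # Hypothetical helper: fill an (n + 1, n + 1) table for every (start, end)
    # pair at once, with no Python loops and no .at[] updates.
    n = signal.shape[0]
    starts = jnp.arange(n + 1)
    ends = jnp.arange(n + 1)
    # Outer vmap over start, inner vmap over end.
    # When end <= start the mask is all False, so that entry is 0.
    return jax.vmap(
        lambda s: jax.vmap(lambda e: segment_sum(signal, s, e))(ends)
    )(starts)
```

Because the shapes never depend on `start` or `end`, the whole table can also be wrapped in `jax.jit`. The cost is O(n) work per entry, which is the usual trade-off of masking versus true dynamic slicing.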

Your first attempt runs into the fact that JAX arrays are immutable, but you can use the functional in-place update syntax instead (see Sharp Bits: In-Place Updates):

import jax.numpy as jnp
from jax import random

n_samples = 4
seed = 123
key = random.PRNGKey(seed)
signal = random.normal(key=key, shape=(n_samples,))


def f_loop(signal):
  n_samples = len(signal)
  res = jnp.zeros((n_samples + 1, n_samples + 1))
  for start in range(n_samples):
    for end in range(start + 1, n_samples + 1):
      res = res.at[start, end].set(si…
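For reference, here is a minimal standalone illustration of the `.at[...]` update syntax used in the loop above. NumPy-style item assignment raises a `TypeError` on a JAX array, while `.at[idx].set(...)` and `.at[idx].add(...)` return a new array and leave the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# NumPy-style assignment fails because JAX arrays are immutable.
try:
    x[0] = 1.0
except TypeError as err:
    print("in-place assignment failed:", type(err).__name__)

# The functional form returns a *new* array with the update applied.
y = x.at[0].set(1.0)   # y is [1., 0., 0.]
z = y.at[1].add(5.0)   # z is [1., 5., 0.]

print(x)  # x itself is still all zeros
```

Outside of `jit`, each such update makes a copy. Inside a `jit`-compiled function, XLA can often perform the update in place, so the functional style does not necessarily cost extra memory.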

Answer selected by deepcharles