Apply function (jnp.var) on all sub-segments of a 1D array (TracerIntegerConversionError) #15557
-
Hi, thanks for the amazing work. I find it very useful. I have a problem where I need to compute the variance of all subsegments of a 1D array, i.e. `signal[start:end].var()` for every pair `start < end`. Can you think of a way to make this code work (the second attempt, mostly)?

```python
import jax.numpy as jnp
import jax.random as jrd
from jax import lax
from jax.typing import ArrayLike as JArrayLike

# generate dummy signal
n_samples = 100
seed = 123
key = jrd.PRNGKey(seed)
signal = jrd.normal(key=key, shape=(n_samples,))

# Compute the variance of all subsegments.
# First try: naive numpy translation.
# Does not work because JAX arrays do not support in-place assignment.
res = jnp.zeros((n_samples + 1, n_samples + 1))
for start in range(n_samples):
    for end in range(start + 1, n_samples + 1):
        res[start, end] = signal[start:end].var()  # assignment

# Second try, not sure why this does not work.
# Raises a TracerIntegerConversionError:
#   TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<...>with<...> with
#     val = Array([1, 2, 3, ..., 1, 2, 1], dtype=int32)
#     batch_dim = 0
#   See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
def get_single_var(start: jnp.int32, end: jnp.int32) -> jnp.float32:
    sub = lax.dynamic_slice_in_dim(
        signal, start_index=start, slice_size=end - start
    )
    return jnp.var(sub)

row_indexes, col_indexes = jnp.triu_indices(n_samples + 1, k=1)
res = jnp.vectorize(get_single_var)(row_indexes, col_indexes)
```

Cheers, Charles
-
After looking at #2323, I suspect that this task might not be solved efficiently with JAX. But I'd like your opinion on that.
-
Hi - thanks for the question. Working with algorithms that involve dynamically-shaped arrays sometimes requires some thought, but it is often still possible to do this efficiently. Your first attempt is running into JAX's immutable arrays, but you can use the functional in-place update syntax instead (see Sharp Bits: In Place Updates):

```python
import jax.numpy as jnp
from jax import random

n_samples = 4
seed = 123
key = random.PRNGKey(seed)
signal = random.normal(key=key, shape=(n_samples,))

def f_loop(signal):
    n_samples = len(signal)
    res = jnp.zeros((n_samples + 1, n_samples + 1))
    for start in range(n_samples):
        for end in range(start + 1, n_samples + 1):
            res = res.at[start, end].set(signal[start:end].var())
    return res

f_loop(signal)
```
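As a quick sanity check (a small sketch added here, not part of the original reply), the functional-update version can be compared entry by entry against a plain NumPy double loop; the variable names below are just illustrative:

```python
import numpy as np

# Reference result computed with mutable NumPy arrays.
signal_np = np.asarray(signal)
expected = np.zeros((n_samples + 1, n_samples + 1))
for start in range(n_samples):
    for end in range(start + 1, n_samples + 1):
        expected[start, end] = signal_np[start:end].var()

# The functional .at[...].set(...) version should match the NumPy reference.
assert np.allclose(f_loop(signal), expected, atol=1e-6)
```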
If you'd like to do the same thing more efficiently, you can rewrite the dynamic-length variance as a reduction over a masked, fixed-shape array and vectorize it with `jax.vmap`:

```python
import jax

def dynamic_var(x, start, end):
    """Compute x[start:end].var() with dynamic start and end."""
    i = jnp.arange(len(x))
    x = jnp.where((i >= start) & (i < end), x, 0)
    return jnp.sum(x ** 2) / (end - start) - jnp.sum(x) ** 2 / (end - start) ** 2

def f_vmap(signal):
    n_samples = len(signal)
    i, j = jnp.triu_indices(n_samples + 1, k=2)
    res = jnp.zeros((n_samples + 1, n_samples + 1))
    return res.at[i, j].set(jax.vmap(dynamic_var, (None, 0, 0))(signal, i, j))

f_vmap(signal)
```
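To convince yourself that the masked formulation agrees with the loop version, you can compare the two directly (a sketch of my own, assuming the definitions above are in scope):

```python
import numpy as np

# f_loop fills every (start, end) with end > start, while f_vmap only fills
# end - start >= 2; single-element segments have zero variance, so the
# remaining entries are zero in both results and the matrices should match.
assert np.allclose(f_loop(signal), f_vmap(signal), atol=1e-6)
```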
This second approach should be much faster than the simpler approach based on nested for-loops, particularly as `n_samples` grows.
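If you want to measure this yourself, one rough way (a sketch not from the original reply; timings will depend on your backend) is to jit-compile `f_vmap` and time both versions on a larger input. Note that jit-compiling `f_loop` would unroll roughly `n_samples**2 / 2` scatter updates into the traced program, so compilation itself becomes the bottleneck there:

```python
import time
import jax
from jax import random

signal_large = random.normal(random.PRNGKey(0), shape=(200,))

f_vmap_jit = jax.jit(f_vmap)
f_vmap_jit(signal_large).block_until_ready()  # warm-up / compile

start = time.perf_counter()
f_vmap_jit(signal_large).block_until_ready()
print("f_vmap (jitted):", time.perf_counter() - start)

start = time.perf_counter()
f_loop(signal_large).block_until_ready()  # eager Python loops: ~20,000 iterations
print("f_loop (eager): ", time.perf_counter() - start)
```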