How to replace increasing subsequences by their averages? #8862
-
Consider an array:
It has two (strictly) increasing subsequences. So a …
The routine identifies every maximal (strictly) increasing subsequence of consecutive elements, computes the average of each such subsequence, and then replaces each subsequence with its average. Both the number of such subsequences and the length of each one are data-dependent. How can I code such a function efficiently [with JIT support] in JAX? A C or naive implementation is not that difficult, but I am scratching my head over how to identify the dynamic slices and then average over a dynamic number of dynamically sized slices. This is an intermediate step in the computation of the proximal operator for a weighted sorted l1 norm. Reference: Bogdan, Małgorzata, et al. "Statistical estimation and testing via the ordered l1 norm." (2013).
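For reference, the "naive implementation" mentioned above might look like the loop below. It is a sketch, not the asker's code: the name `average_increasing_runs_naive` is hypothetical, and it assumes a "subsequence" means a maximal run of consecutive elements, each strictly greater than the one before it. A loop like this is handy as ground truth when testing a vectorized JAX version, but its data-dependent Python control flow is exactly what makes it awkward to JIT.

```python
def average_increasing_runs_naive(x):
    """Replace each maximal strictly increasing run of consecutive elements
    with the run's mean. Hypothetical reference helper; plain Python, no JIT."""
    out = []
    start = 0  # index where the current run began
    n = len(x)
    for i in range(1, n + 1):
        # The current run ends at the end of the input, or where the next
        # element does not strictly increase.
        if i == n or x[i] <= x[i - 1]:
            run = x[start:i]
            avg = sum(run) / len(run)
            out.extend([avg] * len(run))
            start = i
    return out


print(average_increasing_runs_naive([1, 3, 2, 4, 6, 0]))
# [2.0, 2.0, 4.0, 4.0, 4.0, 0.0]
```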
-
Here's one way you might do it, using `jnp.cumsum` over a mask to determine unique IDs for each segment, then computing the mean using two passes of `jax.ops.segment_sum`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def replace_increasing_subsequences_with_averages(x):
    # mask[i] is True when x[i] continues the increasing run ending at x[i-1].
    mask = jnp.zeros(len(x), dtype=bool).at[1:].set(jnp.diff(x) > 0)
    # Each False in mask starts a new run, so the running count is a run ID.
    segment_ids = jnp.cumsum(~mask) - 1
    # len(x) is a static upper bound on the number of runs.
    sums = jax.ops.segment_sum(x, segment_ids, num_segments=len(x))
    norms = jax.ops.segment_sum(jnp.ones_like(x), segment_ids, num_segments=len(x))
    # Gather each run's mean back to every position in that run.
    return (sums / norms)[segment_ids]

x = jnp.array([5, 5, 4, 4, 6, 8, 10, 5, 4, 3, 3, 2, 2, 3, 4, 7, 4, 2, 1, 1])
print(replace_increasing_subsequences_with_averages(x))
# [5. 5. 4. 7. 7. 7. 7. 5. 4. 3. 3. 2. 4. 4. 4. 4. 4. 2. 1. 1.]
```
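To see why the cumulative sum yields one ID per run, it can help to look at the intermediate arrays for a small input. The sketch below splits the same mask-and-`cumsum` idea into a helper; the names `segment_ids_for_increasing_runs` and `run_averages` are hypothetical. Because the IDs produced this way are non-decreasing, it also passes `indices_are_sorted=True` to `jax.ops.segment_sum`. That flag is only a hint that may allow a faster lowering on some backends, so treat any speedup as an assumption to benchmark.

```python
import jax
import jax.numpy as jnp


def segment_ids_for_increasing_runs(x):
    # Hypothetical helper. continues_run[i] is True where x[i] > x[i-1];
    # position 0 always starts a new run, so it stays False.
    continues_run = jnp.zeros(x.shape[0], dtype=bool).at[1:].set(jnp.diff(x) > 0)
    # Every False marks the start of a run; counting them gives 0-based,
    # non-decreasing run IDs.
    return jnp.cumsum(~continues_run) - 1


@jax.jit
def run_averages(x):
    # Hypothetical name; same technique as the answer's function.
    ids = segment_ids_for_increasing_runs(x)
    n = x.shape[0]  # static under jit: an upper bound on the number of runs
    sums = jax.ops.segment_sum(x, ids, num_segments=n, indices_are_sorted=True)
    counts = jax.ops.segment_sum(jnp.ones_like(x), ids, num_segments=n,
                                 indices_are_sorted=True)
    # Unused segments have count 0 (giving nan), but they are never gathered.
    return (sums / counts)[ids]


x = jnp.array([1.0, 3.0, 2.0, 2.0, 5.0, 9.0, 0.0])
print(segment_ids_for_increasing_runs(x))  # [0 0 1 2 2 2 3]
print(run_averages(x))

# The output shape depends only on the input shape, so batching works too.
batch = jnp.stack([x, x[::-1]])
batched = jax.vmap(run_averages)(batch)
print(batched)
```

Because `num_segments` is fixed by the static array length rather than by the data, the output shape never depends on how many runs there are, which is what lets the function be traced once and mapped over equal-length rows with `jax.vmap`.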