Here's one way you might do it: use jnp.cumsum over a mask to assign a unique ID to each strictly increasing run, then compute each run's mean with two passes of jax.ops.segment_sum (one for the per-run sums, one for the per-run element counts):

import jax
import jax.numpy as jnp

@jax.jit
def replace_increasing_subsequences_with_averages(x):
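  # mask[i] is True when x[i] > x[i-1], i.e. x[i] continues an increasing run.
  mask = jnp.zeros(len(x), dtype=bool).at[1:].set(jnp.diff(x) > 0)
  # Every False starts a new run, so a running count of ~mask gives 0-based run IDs.
  segment_ids = jnp.cumsum(~mask) - 1
  # num_segments must be static under jit; len(x) is an upper bound on the run count.
  sums = jax.ops.segment_sum(x, segment_ids, num_segments=len(x))
  norms = jax.ops.segment_sum(jnp.ones_like(x), segment_ids, num_segments=len(x))
  # Unused segments evaluate to 0/0 = nan, but they are never selected by the indexing.
  return (sums / norms)[segment_ids]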

x = jnp.array([5, 5, 4, 4, 6, 8, 10, 5, 4, 3, 3, 2, 2, 3, 4, 7, 4, 2, 1, 1])
print(replace_increasing_subsequences_with_averages(x))
# [5. …
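
To make the intermediate steps easier to follow, here is a small sketch that runs the same operations outside of jit on a short array and shows each intermediate value. The input array and the name `counts` (used in place of `norms`) are chosen here for illustration and are not part of the answer above.

```python
import jax
import jax.numpy as jnp

# Illustrative input (an assumption, not from the original answer).
x = jnp.array([1, 3, 2, 2, 4, 6])

# True where an element is strictly greater than its predecessor.
mask = jnp.zeros(len(x), dtype=bool).at[1:].set(jnp.diff(x) > 0)
# mask        -> [False, True, False, False, True, True]

# A new run starts at every False, so counting them yields run IDs.
segment_ids = jnp.cumsum(~mask) - 1
# segment_ids -> [0, 0, 1, 2, 2, 2]

sums = jax.ops.segment_sum(x, segment_ids, num_segments=len(x))
counts = jax.ops.segment_sum(jnp.ones_like(x), segment_ids, num_segments=len(x))
# sums        -> [4, 2, 12, 0, 0, 0]
# counts      -> [2, 1, 3, 0, 0, 0]

# Broadcast each run's mean back to the positions belonging to that run.
means = (sums / counts)[segment_ids]
print(means)
# means       -> [2. 2. 2. 4. 4. 4.]
```

The trailing zeros in `sums` and `counts` come from padding the number of segments up to `len(x)`. Their quotient is nan, but since `segment_ids` only refers to runs that actually exist, those entries never appear in the result.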

Answer selected by shailesh1729