
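Currently JAX doesn't support dynamic shapes. Under `jit`, every array shape must be a concrete value at trace time, so the following cannot be traced when `i` is a traced value: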

```python
n = jnp.int32(10**(3+i*4/50))
xi = jax.random.uniform(subkey, (n,), minval=-3, maxval=3)
```
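A minimal sketch of the limitation follows. The helper names `sample_traced` and `sample_static` are hypothetical, and `key` stands in for the question's `subkey`. Under `jax.jit`, the sample size becomes a tracer and cannot be used as a shape. Passing `n` through `static_argnums` works instead, but JAX then compiles a separate version for every distinct `n`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def sample_traced(key, i):
    # The sample size depends on the *value* of i (hypothetical helper).
    n = jnp.int32(10 ** (3 + i * 4 / 50))
    return jax.random.uniform(key, (n,), minval=-3, maxval=3)


@partial(jax.jit, static_argnums=1)
def sample_static(key, n):
    # n is a static (hashable Python) argument, so it is a concrete int at
    # trace time and can be used as a shape. Each new n triggers a recompile.
    return jax.random.uniform(key, (n,), minval=-3, maxval=3)


key = jax.random.PRNGKey(0)

try:
    jax.jit(sample_traced)(key, 10)
    traced_failed = False
except TypeError:
    # Under jit, i is a tracer, so n is a tracer too and cannot be a shape.
    traced_failed = True

xs = sample_static(key, 1000)
print(traced_failed, xs.shape)  # True (1000,)
```

Looping over many sample sizes with `sample_static` therefore costs one compilation per size. A single call that handles all sizes at once avoids that cost.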

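Here is a hacky way to avoid the for-loop over `n`. Draw the samples for all sample sizes in one flat array, and label each sample with the index of the sample size it belongs to (`segment_ids`). Then reduce each group separately with `jax.ops.segment_sum`. The total number of samples is fixed by the shape of `segment_ids`, so every shape stays static: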

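```python
import jax
import jax.numpy as jnp

def prob(x):
    # Unnormalized weight: two Gaussian-shaped bumps centered at 0 and 1.5
    return (8 / 0.5) * jnp.exp(-0.5 * (x / 0.5) ** 2) + (
        (2 / 0.1) * jnp.exp(-0.5 * ((x - 1.5) / 0.1) ** 2)
    )

def phi(x):
    return x ** 2 / 8


@jax.jit  # jit-compatible: every shape depends only on input shapes, not values
def mc_multiple_n(ns, segment_ids, key):
    # One flat batch holding the samples for every n; segment_ids[j] is the
    # index (into ns) of the sample size that sample j belongs to.
    # Same sampling range as in the question.
    xs = jax.random.uniform(key, (len(segment_ids),), minval=-3, maxval=3)
    phis = jax.vmap(phi)(xs)
    probs = jax.vmap(prob)(xs)
    # Per-segment weighted average: sum(phi * p) / sum(p) for each n
    return jax.ops.segment_sum(phis * probs, segment_ids, len(ns)) / jax.ops.segment_sum(probs, segment_ids, len(ns))


ns = jnp.int32(10 **
```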
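The definition of `ns` is cut off above, so the following is only a sketch under assumed values. It uses hypothetical sample sizes `ns = [3, 1, 2]` and builds `segment_ids` by repeating each segment index `ns[k]` times. On toy data, it shows how `jax.ops.segment_sum` reduces each group separately, which is what `mc_multiple_n` relies on.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical sample sizes (the original `ns` line is truncated).
ns = np.array([3, 1, 2])

# One id per sample, segment k repeated ns[k] times -> [0, 0, 0, 1, 2, 2].
# Built with NumPy outside jit, so its length sum(ns) is a concrete value.
segment_ids = jnp.asarray(np.repeat(np.arange(len(ns)), ns))

# Toy per-sample values standing in for phi(x) * prob(x).
values = jnp.array([1.0, 2.0, 3.0, 10.0, 4.0, 6.0])

# Sum within each segment; num_segments must be a static Python int.
sums = jax.ops.segment_sum(values, segment_ids, num_segments=len(ns))
counts = jax.ops.segment_sum(jnp.ones_like(values), segment_ids, num_segments=len(ns))
means = sums / counts

print(sums)   # per-segment sums: 6, 10, 10
print(means)  # per-segment means: 2, 10, 5
```

With `segment_ids` built this way from the real `ns`, `mc_multiple_n(ns, segment_ids, key)` returns one ratio `sum(phi * p) / sum(p)` per sample size.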

Answer selected by jecampagne
Category
Q&A