How to avoid a for-loop w/o throwing an error... #10335
-
Hello,
import jax
import jax.numpy as jnp
def prob(x, norm=1):
    """Return an unnormalised two-Gaussian mixture weight at x, divided by norm.

    The broad component has amplitude 8/0.5, is centred at 0 and has width 0.5.
    The narrow component has amplitude 2/0.1, is centred at 1.5 and has width 0.1.
    """
    broad = (8 / 0.5) * jnp.exp(-0.5 * (x / 0.5) ** 2)
    narrow = (2 / 0.1) * jnp.exp(-0.5 * ((x - 1.5) / 0.1) ** 2)
    mixture = broad + narrow
    return mixture / norm
def phi(x):
    """Return the integrand x**2 / 8 whose prob-weighted average is estimated."""
    squared = x ** 2
    return squared / 8
def f(x, norm=1):
    """Return phi(x) weighted by prob(x, norm).

    Summing this over the sample points gives the numerator of the
    self-normalised estimate computed in body_fun.
    """
    weight = prob(x, norm)
    return phi(x) * weight
def body_fun(carry, i, n_max=10**7):
    """Run one self-normalised Monte Carlo step as the body of jax.lax.scan.

    The step draws uniform samples on [-3, 3] and returns
    sum(f(x)) / sum(prob(x)) over the first n of them, where
    n = round(10**(3 + 4*i/50)).

    Args:
        carry: PRNG key threaded through the scan.
        i: Scan index. It is a tracer inside scan, so n cannot be used as a
            shape directly.
        n_max: Static number of samples actually drawn per step. Steps asking
            for more than n_max samples are capped at n_max. The default 10**7
            covers every i up to 50.

    Returns:
        (new_key, estimate), the (carry, output) pair jax.lax.scan expects.
    """
    key, subkey = jax.random.split(carry)
    # n depends on the traced index i, and JAX needs static array shapes.
    # Draw a fixed n_max samples and mask out everything past the first n.
    # Round before casting, because truncation could drop a sample.
    n = jnp.round(10 ** (3 + i * 4 / 50)).astype(jnp.int32)
    xi = jax.random.uniform(subkey, (n_max,), minval=-3, maxval=3)
    keep = jnp.arange(n_max) < n
    Z_hat = jnp.sum(jnp.where(keep, prob(xi), 0.0))
    I0 = jnp.sum(jnp.where(keep, f(xi), 0.0)) / Z_hat
    # Carry the fresh key. subkey was already consumed by uniform above,
    # and reusing it would correlate later draws with these ones.
    return key, I0
# Initial PRNG key, passed to jax.lax.scan as the starting carry.
init = jax.random.PRNGKey(seed=20)
_,info = jax.lax.scan(body_fun, init, jnp.arange(len(Ns)))

I guess you have recognized that this is an integration of
Trying the suggested correction leads to other errors. I have tried different ways of passing the number of samples, but nothing leads to working code, and the same happens with fori_loop. A simple version that does work is:
But I was looking for a more JAX-idiomatic version, i.e. one without the for-loop. Any help is welcome.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 8 replies
-
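Currently JAX doesn't support dynamic shapes. In your `body_fun`, the sample count `n` depends on the traced scan index `i`, so these lines cannot be traced:

n = jnp.int32(10**(3+i*4/50))
xi = jax.random.uniform(subkey, (n,), minval=-3,maxval=3)

Here is a hacky way to avoid the for-loop:

import jax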
import jax.numpy as jnp
def prob(x):
    """Return an unnormalised two-Gaussian mixture weight evaluated at x.

    One term is centred at 0 with width 0.5 and amplitude 8/0.5. The other is
    centred at 1.5 with width 0.1 and amplitude 2/0.1.
    """
    centre_term = (8 / 0.5) * jnp.exp(-0.5 * (x / 0.5) ** 2)
    peak_term = (2 / 0.1) * jnp.exp(-0.5 * ((x - 1.5) / 0.1) ** 2)
    return centre_term + peak_term
def phi(x):
    """Return the integrand x squared over 8, evaluated at x."""
    x_sq = x ** 2
    return x_sq / 8
@jax.jit  # every shape used here comes from len(...) of the inputs, so it is static under jit
def mc_multiple_n(ns, segment_ids, key, minval=-3.0, maxval=3.0):
    """Compute self-normalised Monte Carlo estimates for several sample sizes in one pass.

    Looping over sample sizes would need a different array shape per size,
    which JAX cannot trace. Instead, this draws one flat batch of
    ``len(segment_ids)`` uniform samples. ``segment_ids`` assigns each draw to
    one of ``len(ns)`` groups, and per-group sums give one estimate per size.

    Args:
        ns: 1-D array of sample counts. Only its length is used here.
        segment_ids: 1-D int array with one group label in ``[0, len(ns))``
            per draw, e.g. ``jnp.repeat(jnp.arange(len(ns)), ns)``.
        key: PRNG key for the uniform draws.
        minval: Lower end of the sampling interval. Defaults to -3 to match
            the question's integration range; jax.random.uniform alone would
            sample [0, 1) and estimate a different integral.
        maxval: Upper end of the sampling interval (default 3).

    Returns:
        Array of shape ``(len(ns),)`` holding ``sum(phi * prob) / sum(prob)``
        for each group.
    """
    num_segments = len(ns)
    xs = jax.random.uniform(
        key, (len(segment_ids),), minval=minval, maxval=maxval
    )
    # prob and phi are elementwise jnp expressions, so no vmap is needed.
    probs = prob(xs)
    numer = jax.ops.segment_sum(phi(xs) * probs, segment_ids, num_segments)
    denom = jax.ops.segment_sum(probs, segment_ids, num_segments)
    return numer / denom


# Sample sizes are 10**(3 + 4k/50) for k = 0..9. Round before casting, because
# a float32 power can land just below an integer and the int32 cast truncates.
ns = jnp.round(10 ** (3 + jnp.arange(10) * 4 / 50)).astype(jnp.int32)
# Group g owns the g-th run of ns[g] consecutive draws.
segment_ids = jnp.repeat(jnp.arange(len(ns)), ns)
key = jax.random.PRNGKey(0)
print(mc_multiple_n(ns, segment_ids, key))
# The estimates should approach the exact value as ns grows. The mixture
# weights are 8:2, the means are 0 and 1.5, and the widths are 0.5 and 0.1:
# (0.8 * 0.5**2 + 0.2 * (1.5**2 + 0.1**2)) / 8 = 0.0815
Beta Was this translation helpful? Give feedback.
Currently JAX doesn't support dynamic shapes.
Here is a hacky way to avoid the for-loop: