|
| 1 | +import sys |
| 2 | + |
| 3 | +sys.path.insert(0, "Serialized/Fwd") |
| 4 | +sys.path.insert(0, "Serialized/Bwd") |
| 5 | + |
| 6 | + |
| 7 | +import jax |
| 8 | +jax.config.update('jax_enable_x64', True) |
| 9 | +import jax.numpy as jnp |
| 10 | +import blackjax |
| 11 | +import enzyme_ad |
| 12 | +from functools import partial |
| 13 | + |
| 14 | +import logdensityof as lg |
| 15 | +import gl as gl |
| 16 | + |
| 17 | +from logdensityof import run_logdensityof |
| 18 | +from gl import run_gl |
| 19 | + |
| 20 | +lg_inputs = lg.load_inputs() |
| 21 | +gl_inputs = gl.load_inputs() |
| 22 | + |
| 23 | +tpost = lg_inputs[:-1] |
| 24 | +xr = lg_inputs[-1] |
| 25 | + |
| 26 | + |
| 27 | + |
| 28 | + |
| 29 | +jlr = jax.jit(run_logdensityof) |
| 30 | + |
| 31 | +jtpost0 = jnp.array(tpost[0]) |
| 32 | +jtpost1 = jnp.array(tpost[1]) |
| 33 | +jtpost2 = jnp.array(tpost[2]) |
| 34 | +jtpost3 = jnp.array(tpost[3]) |
| 35 | +jtpost4 = jnp.array(tpost[4]) |
| 36 | +jxr = jnp.array(xr) |
| 37 | + |
| 38 | +out = jlr(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, xr) |
| 39 | + |
| 40 | +run_logdensityof(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, jxr) |
| 41 | +run_gl(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, jxr) |
| 42 | + |
| 43 | + |
| 44 | +@jax.custom_vjp |
| 45 | +def f(x): |
| 46 | + out = run_logdensityof(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, x) |
| 47 | + return out[0] |
| 48 | + |
| 49 | +def f_fwd(x): |
| 50 | + j = run_gl(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, x)[0] |
| 51 | + return f(x), (j,) |
| 52 | + |
| 53 | +def f_bwd(res, g): |
| 54 | + j = res[0] |
| 55 | + return (g * j,) |
| 56 | + |
| 57 | +f.defvjp(f_fwd, f_bwd) |
| 58 | + |
| 59 | +logdensity = lambda x: f(**x) |
| 60 | + |
| 61 | +inv_mass_matrix = jnp.ones(len(jxr)) |
| 62 | +initial_position = {"x": jxr} |
| 63 | + |
| 64 | +rng_key, sample_key = jax.random.split(jax.random.PRNGKey(0)) |
| 65 | + |
| 66 | +# adaptation |
| 67 | +warmup = blackjax.window_adaptation(blackjax.nuts, logdensity, progress_bar=True) |
| 68 | +rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3) |
| 69 | +(state, parameters), _ = warmup.run(warmup_key, initial_position, num_steps=1000) |
| 70 | + |
| 71 | + |
| 72 | +def inference_loop(rng_key, kernel, init, nsamples): |
| 73 | + @jax.jit |
| 74 | + def step(state, rng_key): |
| 75 | + state, _ = kernel(rng_key, state) |
| 76 | + return state, state |
| 77 | + |
| 78 | + keys = jax.random.split(rng_key, nsamples) |
| 79 | + _, states = jax.lax.scan(step, init, keys) |
| 80 | + return states |
| 81 | + |
| 82 | +# inference loop |
| 83 | +rng_key, sample_key = jax.random.split(jax.random.PRNGKey(0)) |
| 84 | +kernel = blackjax.nuts(logdensity, **parameters).step |
| 85 | +states = inference_loop(sample_key, kernel, state, nsamples=1000) |
| 86 | + |
| 87 | + |
0 commit comments