Skip to content

Commit 63587b4

Browse files
committed
Added lotvol_model on regular scale
1 parent 173375e commit 63587b4

File tree

5 files changed

+474
-101
lines changed

5 files changed

+474
-101
lines changed

examples/lotvol.ipynb

Lines changed: 155 additions & 34 deletions
Large diffs are not rendered by default.

examples/pgnet.ipynb

Lines changed: 82 additions & 57 deletions
Large diffs are not rendered by default.

src/pfjax/experimental/models/__init__.py

Whitespace-only changes.
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""
2+
Lotka-Volterra predator-prey model on the regular scale.
3+
4+
The model is:
5+
6+
```
7+
x_m0 \propto 1
8+
H_mt ~ N(H_{m,t-1} + (alpha H_{m,t-1} - beta H_{m,t-1}L_{m,t-1}) dt/m,
9+
sigma_H^2 dt/m)
10+
L_mt ~ N(L_{m,t-1} + (-gamma L_{m,t-1} + delta H_{m,t-1}L_{m,t-1}) dt/m,
11+
sigma_L^2 dt/m)
12+
y_t ~ N(x_{m,mt}, diag(tau_H^2, tau_L^2) )
13+
```
14+
15+
- Model parameters: `theta = (alpha, beta, gamma, delta, sigma_H, sigma_L, tau_H, tau_L)`.
16+
- Global constants: `dt` and `n_res`, i.e., `m`.
17+
- State dimensions: `n_state = (n_res, 2)`.
18+
- Measurement dimensions: `n_meas = 2`.
19+
20+
**Notes:**
21+
22+
- The measurement `y_t` corresponds to `x_t = (x_{m,(t-1)m+1}, ..., x_{m,tm})`, i.e., aligns with the last element of `x_t`.
23+
- The prior is such that `p(x_0 | y_0, theta)` is given by:
24+
25+
```
26+
x_{m,n} = 0 for n = -m+1, ..., -1,
27+
x_{m0} ~ TruncatedNormal( y_0, diag(tau_H^2, tau_L^2) ),
28+
```
29+
30+
where
31+
32+
```
33+
z ~ TruncatedNormal(mu, diag(sigma^2)) <=>
34+
z = mu + diag(sigma) Z_0, Z_0 ~iid N(0,1) truncated at -mu.
35+
```
36+
37+
"""
38+
39+
import jax
40+
import jax.numpy as jnp
41+
import jax.scipy as jsp
42+
from jax import random
43+
from jax import lax
44+
from pfjax import sde as sde
45+
46+
# --- helper functions ---------------------------------------------------------
47+
48+
49+
def lotvol_drift(x, dt, theta):
50+
r"""
51+
Calculates the SDE drift function.
52+
"""
53+
alpha = theta[0]
54+
beta = theta[1]
55+
gamma = theta[2]
56+
delta = theta[3]
57+
return x + jnp.array([alpha * x[0] - beta * x[0] * x[1],
58+
-gamma * x[1] + delta * x[0] * x[1]]) * dt
59+
60+
61+
# --- main functions -----------------------------------------------------------
62+
63+
class RegLotVolModel(sde.SDEModel):
64+
def __init__(self, dt, n_res):
65+
r"""
66+
Class constructor for the Lotka-Volterra model.
67+
68+
Args:
69+
dt: SDE interobservation time.
70+
n_res: SDE resolution number. There are `n_res` latent variables per observation, equally spaced with interobservation time `dt/n_res`.
71+
"""
72+
# creates "private" variables self._dt and self._n_res
73+
super().__init__(dt, n_res, diff_diag=True)
74+
# self.dt = dt
75+
# self.n_res = n_res
76+
# the following variable is mainly used for testing, i.e.,
77+
# in the `_for` versions of certain methods.
78+
# it does contain the number of SDE dimensions, which is used
79+
# outside of testing, but can be circumvented by pulling shape
80+
# from input arguments. however, this may fail less informatively
81+
# than if using prespecified SDE dimensions...
82+
self._n_state = (self._n_res, 2)
83+
84+
def drift(self, x, theta):
85+
r"""
86+
Calculates the SDE drift function.
87+
"""
88+
alpha = theta[0]
89+
beta = theta[1]
90+
gamma = theta[2]
91+
delta = theta[3]
92+
return jnp.array([alpha * x[0] - beta * x[0] * x[1],
93+
-gamma * x[1] + delta * x[0] * x[1]])
94+
95+
def diff(self, x, theta):
96+
r"""
97+
Calculates the SDE diffusion function.
98+
"""
99+
return theta[4:6]
100+
101+
def state_lpdf_for(self, x_curr, x_prev, theta):
102+
r"""
103+
Calculates the log-density of `p(x_curr | x_prev, theta)`.
104+
105+
For-loop version for testing.
106+
107+
Args:
108+
x_curr: State variable at current time `t`.
109+
x_prev: State variable at previous time `t-1`.
110+
theta: Parameter value.
111+
Returns:
112+
The log-density of `p(x_curr | x_prev, theta)`.
113+
"""
114+
dt_res = self._dt/self._n_res
115+
x0 = jnp.append(jnp.expand_dims(
116+
x_prev[self._n_res-1], axis=0), x_curr[:self._n_res-1], axis=0)
117+
x1 = x_curr
118+
sigma = theta[4:6] * jnp.sqrt(dt_res)
119+
lp = jnp.array(0.0)
120+
for t in range(self._n_res):
121+
lp = lp + jnp.sum(jsp.stats.norm.logpdf(
122+
x1[t],
123+
loc=lotvol_drift(x0[t], dt_res, theta),
124+
scale=sigma
125+
))
126+
return lp
127+
128+
def state_sample_for(self, key, x_prev, theta):
129+
r"""
130+
Samples from `x_curr ~ p(x_curr | x_prev, theta)`.
131+
132+
For-loop version for testing.
133+
134+
Args:
135+
key: PRNG key.
136+
x_prev: State variable at previous time `t-1`.
137+
theta: Parameter value.
138+
139+
Returns:
140+
Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
141+
"""
142+
dt_res = self._dt/self._n_res
143+
sigma = theta[4:6] * jnp.sqrt(dt_res)
144+
x_curr = jnp.zeros(self._n_state)
145+
x_state = x_prev[self._n_res-1]
146+
for t in range(self._n_res):
147+
key, subkey = random.split(key)
148+
x_state = lotvol_drift(x_state, dt_res, theta) + \
149+
random.normal(subkey, (self._n_state[1],)) * sigma
150+
x_curr = x_curr.at[t].set(x_state)
151+
return x_curr
152+
153+
def meas_lpdf(self, y_curr, x_curr, theta):
154+
r"""
155+
Log-density of `p(y_curr | x_curr, theta)`.
156+
157+
Args:
158+
y_curr: Measurement variable at current time `t`.
159+
x_curr: State variable at current time `t`.
160+
theta: Parameter value.
161+
162+
Returns
163+
The log-density of `p(y_curr | x_curr, theta)`.
164+
"""
165+
tau = theta[6:8]
166+
return jnp.sum(
167+
jsp.stats.norm.logpdf(y_curr,
168+
loc=x_curr[-1], scale=tau)
169+
)
170+
171+
def meas_sample(self, key, x_curr, theta):
172+
r"""
173+
Sample from `p(y_curr | x_curr, theta)`.
174+
175+
Args:
176+
key: PRNG key.
177+
x_curr: State variable at current time `t`.
178+
theta: Parameter value.
179+
180+
Returns:
181+
Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
182+
"""
183+
tau = theta[6:8]
184+
return x_curr[-1] + \
185+
tau * random.normal(key, (self._n_state[1],))
186+
187+
def pf_init(self, key, y_init, theta):
188+
r"""
189+
Importance sampler for `x_init`.
190+
191+
See file comments for exact sampling distribution of `p(x_init | y_init, theta)`, i.e., we have a "perfect" importance sampler with `logw = CONST(theta)`.
192+
193+
Args:
194+
key: PRNG key.
195+
y_init: Measurement variable at initial time `t = 0`.
196+
theta: Parameter value.
197+
198+
Returns:
199+
- x_init: A sample from the proposal distribution for `x_init`.
200+
- logw: The log-weight of `x_init`.
201+
"""
202+
tau = theta[6:8]
203+
key, subkey = random.split(key)
204+
x_init = y_init + tau * random.truncated_normal(
205+
subkey,
206+
lower=-y_init/tau,
207+
upper=jnp.inf,
208+
shape=(self._n_state[1],)
209+
)
210+
logw = jnp.sum(jsp.stats.norm.logcdf(y_init/tau))
211+
return \
212+
jnp.append(jnp.zeros((self._n_res-1,) + x_init.shape),
213+
jnp.expand_dims(x_init, axis=0), axis=0), \
214+
logw
215+
216+
def is_valid(self, x, theta):
217+
"""
218+
Checks whether SDE observations are valid.
219+
220+
Args:
221+
x: SDE variables. A vector of size `n_dims`.
222+
theta: Parameter value.
223+
224+
Returns:
225+
Whether or not `x>=0`.
226+
"""
227+
return x >= 0

src/pfjax/sde.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ def is_valid_state(self, x, theta):
198198
"""
199199
valid_x = jax.vmap(self.is_valid, in_axes=(0, None))(x, theta)
200200
nan_x = jnp.any(jnp.isnan(x), axis=1)
201-
return jnp.alltrue(valid_x, where=~nan_x) and jnp.alltrue(~nan_x)
201+
return jnp.alltrue(valid_x, where=~nan_x) & jnp.alltrue(~nan_x)
202+
# return jnp.alltrue(valid_x) and jnp.alltrue(~nan_x)
202203

203204
def state_lpdf(self, x_curr, x_prev, theta):
204205
"""
@@ -353,9 +354,8 @@ def pf_step(self, key, x_prev, y_curr, theta):
353354
x_curr = self.state_sample(key, x_prev, theta)
354355
logw = lax.cond(
355356
self.is_valid_state(x_curr, theta),
356-
lambda _x: self.meas_lpdf(y_curr, x_curr, theta),
357-
lambda _x: -jnp.inf,
358-
0.0
357+
lambda: self.meas_lpdf(y_curr, x_curr, theta),
358+
lambda: -jnp.inf
359359
)
360360
# logw = self.meas_lpdf(y_curr, x_curr, theta)
361361
return x_curr, logw
@@ -436,12 +436,12 @@ def scan_fun(carry, n):
436436
theta=theta
437437
)
438438
logw = logw + self.meas_lpdf(y_curr, x_prop, theta) - last["lp"]
439-
logw = lax.cond(
440-
self.is_valid_state(x_prop, theta),
441-
lambda _x: logw,
442-
lambda _x: -jnp.inf,
443-
0.0
444-
)
439+
# logw = lax.cond(
440+
# self.is_valid_state(x_prop, theta),
441+
# lambda _x: logw,
442+
# lambda _x: -jnp.inf,
443+
# 0.0
444+
# )
445445
return x_prop, logw
446446

447447
def bridge_prop_for(self, key, x_prev, y_curr, theta, Y, A, Omega):

0 commit comments

Comments
 (0)