Skip to content

Commit a5b5db8

Browse files
committed
Added pgnet no dna model
1 parent 63587b4 commit a5b5db8

File tree

4 files changed

+796
-78
lines changed

4 files changed

+796
-78
lines changed

examples/pgnet.ipynb

Lines changed: 312 additions & 58 deletions
Large diffs are not rendered by default.

src/pfjax/experimental/models/pgnet_model_no_DNA.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(self, dt, n_res, bootstrap=True):
6363
super().__init__(dt, n_res, diff_diag=False)
6464
self._n_state = (self._n_res, 4)
6565
self._K = 10
66+
self._eps = 1e-10
6667
self._bootstrap = bootstrap
6768

6869
def _parse_params(self, params):
@@ -74,15 +75,27 @@ def _parse_params(self, params):
7475
dna0 = params[11]
7576
return theta, tau, dna0
7677

78+
def _expit(self, x):
79+
"""
80+
Inverts the logit function.
81+
"""
82+
x4 = self._K*jnp.exp(x[3])/(jnp.exp(x[3]) + 1)
83+
return jnp.append(jnp.exp(x[:3]), x4)
84+
85+
def _logit(self, x):
86+
return jnp.log(x/(self._K - x))
87+
7788
def _drift(self, x, theta):
7889
"""
7990
Calculate the drift on the original scale.
8091
"""
8192
mu1 = theta[2]*x[3] - theta[6]*x[0]
93+
sigma_max = jnp.where(0 < x[1]*(x[1]-1), x[1]*(x[1]-1), 0)
94+
# sigma_max = x[1]*(x[1]-1)
8295
mu2 = 2*theta[5]*x[2] - theta[7]*x[1] + \
83-
theta[3]*x[0] - theta[4]*x[1]*(x[1]-1)
96+
theta[3]*x[0] - theta[4]*sigma_max
8497
mu3 = theta[1]*(self._K-x[3]) - theta[0]*x[3]*x[2] - \
85-
theta[5]*x[2] + 0.5*theta[4]*x[1]*(x[1]-1)
98+
theta[5]*x[2] + 0.5*theta[4]*sigma_max
8699
mu4 = theta[1]*(self._K-x[3]) - theta[0]*x[3]*x[2]
87100
mu = jnp.stack([mu1, mu2, mu3, mu4])
88101
return mu
@@ -94,36 +107,35 @@ def _diff(self, x, theta):
94107
A = theta[0]*x[3]*x[2] + theta[1]*(self._K-x[3])
95108
sigma11 = theta[2]*x[3] + theta[6]*x[0]
96109
sigma_max = jnp.where(0 < x[1]*(x[1]-1), x[1]*(x[1]-1), 0)
97-
sigma_max = x[1]*(x[1]-1)
110+
# sigma_max = x[1]*(x[1]-1)
98111
sigma22 = theta[7]*x[1] + 4*theta[5]*x[2] + \
99112
theta[3]*x[0] + 2*theta[4]*sigma_max
100113
sigma23 = -2*theta[5]*x[2] - theta[4]*sigma_max
101114
sigma33 = A + theta[5]*x[2] + 0.5*theta[4]*sigma_max
102115
sigma34 = A
103116
sigma44 = A
104117

105-
Sigma = jnp.array([[sigma11, 0, 0, 0],
106-
[0, sigma22, sigma23, 0],
107-
[0, sigma23, sigma33, sigma34],
108-
[0, 0, sigma34, sigma44]])
118+
Sigma = jnp.array([[sigma11, 0., 0., 0.],
119+
[0., sigma22, sigma23, 0.],
120+
[0., sigma23, sigma33, sigma34],
121+
[0., 0, sigma34, sigma44]])
109122

110123
return Sigma
111124

112125
def drift(self, x, theta):
113126
"""
114127
Calculates the SDE drift function on the log scale.
115128
"""
116-
x = jnp.exp(x)
117-
# K = self._K
129+
x = self._expit(x) + self._eps
118130
mu = self._drift(x, theta)
119131
Sigma = self._diff(x, theta)
120132

121-
#f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3] + 1/(K-x[3])])
122-
#f_pp = jnp.array([-1/x[0]/x[0], -1/x[1]/x[1], -1/x[2]/x[2], -1/x[3]/x[3] + 1/(K-x[3])/(K-x[3])])
123-
f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3]])
124-
f_pp = jnp.array(
125-
[-1/x[0]/x[0], -1/x[1]/x[1], -1/x[2]/x[2], -1/x[3]/x[3]]
126-
)
133+
f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3] + 1/(self._K-x[3])])
134+
f_pp = jnp.array([-1/x[0]/x[0], -1/x[1]/x[1], -1/x[2]/x[2], -1/x[3]/x[3] + 1/(self._K-x[3])/(self._K-x[3])])
135+
# f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3]])
136+
# f_pp = jnp.array(
137+
# [-1/x[0]/x[0], -1/x[1]/x[1], -1/x[2]/x[2], -1/x[3]/x[3]]
138+
# )
127139

128140
mu_trans = f_p * mu + 0.5 * f_pp * jnp.diag(Sigma)
129141
return mu_trans
@@ -132,12 +144,11 @@ def diff(self, x, theta):
132144
"""
133145
Calculates the SDE diffusion function on the log scale.
134146
"""
135-
x = jnp.exp(x)
136-
# K = self._K
147+
x = self._expit(x) + self._eps
137148
Sigma = self._diff(x, theta)
138149

139-
#f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3] + 1/(K-x[3])])
140-
f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3]])
150+
f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3] + 1/(self._K-x[3])])
151+
# f_p = jnp.array([1/x[0], 1/x[1], 1/x[2], 1/x[3]])
141152
Sigma_trans = jnp.outer(f_p, f_p) * Sigma
142153

143154
return Sigma_trans
@@ -218,7 +229,7 @@ def pf_init(self, key, y_init, theta):
218229
upper=jnp.inf,
219230
shape=(self._n_state[1]-1,)
220231
))
221-
x_init = jnp.append(x_init, jnp.log(dna0))
232+
x_init = jnp.append(x_init, self._logit(dna0))
222233
logw = jnp.sum(jsp.stats.norm.logcdf(y_init/tau))
223234
#x_init = theta[12:16]
224235
#logw = -jnp.float_(0)
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""
2+
Prokaryotic auto-regulatory gene network Model.
3+
4+
The base model involves differential equations of the chemical reactions:
5+
6+
```
7+
DNA + P2 --> DNA_P2
8+
DNA_P2 --> DNA + P2
9+
DNA --> DNA + RNA
10+
RNA --> RNA + P
11+
P + P --> P2
12+
P2 --> P + P
13+
RNA --> 0
14+
P --> 0
15+
```
16+
These equations are associated with a parameter in `theta = (theta0, ..., theta7)`.
17+
The model is approximated by a SDE described in Golightly & Wilkinson (2005).
18+
A particular restriction on the chemical reactions is by the conservation law which implies that `DNA + DNA_P2 = K`.
19+
Thus the SDE model can be described in terms of `x_t = (RNA, P, P2, DNA)`.
20+
21+
Then assuming a standard form of the SDE, the base model can be written as
22+
```
23+
x_mt = x_{m, t-1} + mu_mt dt/m + Sigma_mt^{1/2} dt/m
24+
y_t ~ N( x_{m,mt}, diag(tau^2) )
25+
```
26+
27+
This model is on the regular scale.
28+
29+
30+
- Model parameters: `theta = (theta0, ... theta7, tau0, ... tau3)`.
31+
- Global constants: `dt` and `n_res`, i.e., `m`.
32+
- State dimensions: `n_state = (n_res, 4)`.
33+
- Measurement dimensions: `n_meas = 4`.
34+
35+
"""
36+
37+
import jax
38+
import jax.numpy as jnp
39+
import jax.scipy as jsp
40+
from jax import random
41+
from jax import lax
42+
from pfjax import sde as sde
43+
44+
# --- main functions -----------------------------------------------------------
45+
46+
47+
class RegPGNETModel(sde.SDEModel):
48+
49+
def __init__(self, dt, n_res, bootstrap=True):
50+
r"""
51+
Class constructor for the PGNET model.
52+
53+
Args:
54+
dt: SDE interobservation time.
55+
n_res: SDE resolution number. There are `n_res` latent variables per observation, equally spaced with interobservation time `dt/n_res`.
56+
bootstrap (bool): Flag indicating whether to use a Bootstrap particle filter or a bridge filter.
57+
58+
"""
59+
# creates "private" variables self._dt and self._n_res
60+
super().__init__(dt, n_res, diff_diag=False)
61+
self._n_state = (self._n_res, 4)
62+
self._K = 10
63+
self._eps = 1e-10
64+
self._bootstrap = bootstrap
65+
66+
def drift(self, x, theta):
67+
"""
68+
Calculate the drift on the original scale.
69+
"""
70+
mu1 = theta[2]*x[3] - theta[6]*x[0]
71+
sigma_max = jnp.where(0 < x[1]*(x[1]-1), x[1]*(x[1]-1), 0)
72+
# sigma_max = x[1]*(x[1]-1)
73+
mu2 = 2*theta[5]*x[2] - theta[7]*x[1] + \
74+
theta[3]*x[0] - theta[4]*sigma_max
75+
mu3 = theta[1]*(self._K-x[3]) - theta[0]*x[3]*x[2] - \
76+
theta[5]*x[2] + 0.5*theta[4]*sigma_max
77+
mu4 = theta[1]*(self._K-x[3]) - theta[0]*x[3]*x[2]
78+
mu = jnp.stack([mu1, mu2, mu3, mu4])
79+
return mu
80+
81+
def diff(self, x, theta):
82+
"""
83+
Calculate the diffusion matrix on the original scale.
84+
"""
85+
A = theta[0]*x[3]*x[2] + theta[1]*(self._K-x[3])
86+
sigma11 = theta[2]*x[3] + theta[6]*x[0]
87+
sigma_max = jnp.where(0 < x[1]*(x[1]-1), x[1]*(x[1]-1), 0)
88+
# sigma_max = x[1]*(x[1]-1)
89+
sigma22 = theta[7]*x[1] + 4*theta[5]*x[2] + \
90+
theta[3]*x[0] + 2*theta[4]*sigma_max
91+
sigma23 = -2*theta[5]*x[2] - theta[4]*sigma_max
92+
sigma33 = A + theta[5]*x[2] + 0.5*theta[4]*sigma_max
93+
sigma34 = A
94+
sigma44 = A
95+
96+
Sigma = jnp.array([[sigma11, 0., 0., 0.],
97+
[0., sigma22, sigma23, 0.],
98+
[0., sigma23, sigma33, sigma34],
99+
[0., 0, sigma34, sigma44]])
100+
101+
return Sigma
102+
103+
def meas_lpdf(self, y_curr, x_curr, theta):
104+
"""
105+
Log-density of `p(y_curr | x_curr, theta)`.
106+
107+
Args:
108+
y_curr: Measurement variable at current time `t`.
109+
x_curr: State variable at current time `t`.
110+
theta: Parameter value.
111+
112+
Returns
113+
The log-density of `p(y_curr | x_curr, theta)`.
114+
"""
115+
tau = theta[8:12]
116+
return jnp.sum(
117+
jsp.stats.norm.logpdf(y_curr, loc=x_curr[-1], scale=tau)
118+
)
119+
120+
def meas_sample(self, key, x_curr, theta):
121+
"""
122+
Sample from `p(y_curr | x_curr, theta)`.
123+
124+
Args:
125+
x_curr: State variable at current time `t`.
126+
theta: Parameter value.
127+
key: PRNG key.
128+
129+
Returns:
130+
Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
131+
"""
132+
tau = theta[8:12]
133+
return x_curr[-1] + tau * random.normal(key, (self._n_state[1],))
134+
135+
def pf_init(self, key, y_init, theta):
136+
"""
137+
Particle filter calculation for `x_init`.
138+
139+
Samples from an importance sampling proposal distribution
140+
```
141+
x_init ~ q(x_init) = q(x_init | y_init, theta)
142+
```
143+
and calculates the log weight
144+
```
145+
logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)
146+
```
147+
148+
**FIXME:** Explain what the proposal is and why it gives `logw = 0`.
149+
150+
In fact, if you think about it hard enough then it's not actually a perfect proposal...
151+
152+
Args:
153+
y_init: Measurement variable at initial time `t = 0`.
154+
theta: Parameter value.
155+
key: PRNG key.
156+
157+
Returns:
158+
- x_init: A sample from the proposal distribution for `x_init`.
159+
- logw: The log-weight of `x_init`.
160+
"""
161+
tau = theta[8:12]
162+
# key, subkey = random.split(key)
163+
# x_init = jnp.log(y_init +
164+
# tau * random.normal(subkey, (self.n_state[1],)))
165+
# return \
166+
# jnp.append(jnp.zeros((self.n_res-1,) + x_init.shape),
167+
# jnp.expand_dims(x_init, axis=0), axis=0), \
168+
# jnp.zeros(())
169+
170+
key, subkey = random.split(key)
171+
x_init123 = y_init[:3] + tau[:3] * random.truncated_normal(
172+
subkey,
173+
lower=-y_init[:3]/tau[:3],
174+
upper=jnp.inf,
175+
shape=(self._n_state[1]-1,)
176+
)
177+
178+
x_init4 = y_init[3] + tau[3] * random.truncated_normal(
179+
subkey,
180+
lower=-y_init[3]/tau[3],
181+
upper=(self._K - y_init[3])/tau[3],
182+
shape=(1,)
183+
)
184+
x_init = jnp.append(x_init123, x_init4)
185+
logw = jnp.sum(jsp.stats.norm.logcdf(y_init/tau))
186+
187+
return \
188+
jnp.append(jnp.zeros((self._n_res-1,) + x_init.shape),
189+
jnp.expand_dims(x_init, axis=0), axis=0), \
190+
logw
191+
192+
def pf_step(self, key, x_prev, y_curr, theta):
193+
"""
194+
Choose between bootstrap filter and bridge proposal.
195+
196+
Args:
197+
x_prev: State variable at previous time `t-1`.
198+
y_curr: Measurement variable at current time `t`.
199+
theta: Parameter value.
200+
key: PRNG key.
201+
202+
Returns:
203+
- x_curr: Sample of the state variable at current time `t`: `x_curr ~ q(x_curr)`.
204+
- logw: The log-weight of `x_curr`.
205+
"""
206+
if self._bootstrap:
207+
x_curr, logw = super().pf_step(key, x_prev, y_curr, theta)
208+
else:
209+
omega = theta[8:12]**2
210+
211+
x_curr, logw = self.bridge_prop(
212+
key, x_prev, y_curr, theta,
213+
y_curr, jnp.eye(4), jnp.diag(omega)
214+
)
215+
return x_curr, logw
216+
217+
def is_valid(self, x, theta):
218+
"""
219+
Checks whether SDE observations are valid.
220+
221+
Args:
222+
x: SDE variables. A vector of size `n_dims`.
223+
theta: Parameter value.
224+
225+
Returns:
226+
Whether or not `x>=0`.
227+
"""
228+
return (x >= 0) & (x[3] <= self._K)

0 commit comments

Comments
 (0)