Skip to content

Commit ffea310

Browse files
committed
Experimented with pgnet model with no DNA
1 parent a5b5db8 commit ffea310

File tree

4 files changed

+415
-566
lines changed

4 files changed

+415
-566
lines changed

examples/pgnet.ipynb

Lines changed: 380 additions & 559 deletions
Large diffs are not rendered by default.

src/pfjax/experimental/models/pgnet_model_no_DNA.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,17 @@ def pf_step(self, key, x_prev, y_curr, theta):
263263
jnp.log(y_curr), jnp.eye(4)[:-1, :], jnp.diag(omega)
264264
)
265265
return x_curr, logw
266+
267+
def prop_lpdf(self, x_curr, x_prev, y_curr, theta):
268+
r"""
269+
Calculate the log-density of the proposal distribution `q(x_curr) = q(x_curr | x_prev, y_curr, theta)`.
270+
271+
In this case we have a bootstrap filter, so `q(x_curr) = p(x_curr | x_prev, theta)`.
272+
273+
Args:
274+
x_curr: State variable at current time `t`.
275+
x_prev: State variable at previous time `t-1`.
276+
y_curr: Measurement variable at current time `t`.
277+
theta: Parameter value.
278+
"""
279+
return super().state_lpdf(x_curr=x_curr, x_prev=x_prev, theta=theta)

src/pfjax/experimental/models/pgnet_model_reg_no_DNA.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,4 +222,18 @@ def is_valid(self, x, theta):
222222
Returns:
223223
Whether or not `x>=0`.
224224
"""
225-
return (x >= 0) & (x[3] <= self._K)
225+
return jnp.alltrue(x >= 0) & (x[3] <= self._K)
226+
227+
def prop_lpdf(self, x_curr, x_prev, y_curr, theta):
228+
r"""
229+
Calculate the log-density of the proposal distribution `q(x_curr) = q(x_curr | x_prev, y_curr, theta)`.
230+
231+
In this case we have a bootstrap filter, so `q(x_curr) = p(x_curr | x_prev, theta)`.
232+
233+
Args:
234+
x_curr: State variable at current time `t`.
235+
x_prev: State variable at previous time `t-1`.
236+
y_curr: Measurement variable at current time `t`.
237+
theta: Parameter value.
238+
"""
239+
return super().state_lpdf(x_curr=x_curr, x_prev=x_prev, theta=theta)

src/pfjax/sde.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,12 @@ def scan_fun(carry, n):
436436
theta=theta
437437
)
438438
logw = logw + self.meas_lpdf(y_curr, x_prop, theta) - last["lp"]
439-
# logw = lax.cond(
440-
# self.is_valid_state(x_prop, theta),
441-
# lambda _x: logw,
442-
# lambda _x: -jnp.inf,
443-
# 0.0
444-
# )
439+
logw = lax.cond(
440+
self.is_valid_state(x_prop, theta),
441+
lambda _x: logw,
442+
lambda _x: -jnp.inf,
443+
0.0
444+
)
445445
return x_prop, logw
446446

447447
def bridge_prop_for(self, key, x_prev, y_curr, theta, Y, A, Omega):

0 commit comments

Comments
 (0)