Skip to content

Commit 043beed

Browse files
authored
Split posterior precision update into two branches (#263)
* split posterior update of precision into two branches * add tests * fix api docs
1 parent 80ab968 commit 043beed

File tree

4 files changed

+145
-37
lines changed

4 files changed

+145
-37
lines changed

docs/source/api.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Prediction error steps
8989
Compute the value and volatility prediction errors of a given node. The prediction error can only be computed after the posterior update (or observation) of a given node.
9090

9191
Binary state nodes
92-
^^^^^^^^^^^^^^^^^^
92+
------------------
9393

9494
.. currentmodule:: pyhgf.updates.prediction_error.binary
9595

@@ -100,7 +100,7 @@ Binary state nodes
100100
binary_finite_state_node_prediction_error
101101

102102
Categorical state nodes
103-
^^^^^^^^^^^^^^^^^^^^^^^
103+
-----------------------
104104

105105
.. currentmodule:: pyhgf.updates.prediction_error.categorical
106106

@@ -110,7 +110,7 @@ Categorical state nodes
110110
categorical_state_prediction_error
111111

112112
Continuous state nodes
113-
^^^^^^^^^^^^^^^^^^^^^^
113+
----------------------
114114

115115
.. currentmodule:: pyhgf.updates.prediction_error.continuous
116116

@@ -122,7 +122,7 @@ Continuous state nodes
122122
continuous_node_prediction_error
123123

124124
Dirichlet state nodes
125-
^^^^^^^^^^^^^^^^^^^^^
125+
---------------------
126126

127127
.. currentmodule:: pyhgf.updates.prediction_error.dirichlet
128128

@@ -137,7 +137,7 @@ Dirichlet state nodes
137137
clusters_likelihood
138138

139139
Exponential family
140-
^^^^^^^^^^^^^^^^^^
140+
------------------
141141

142142
.. currentmodule:: pyhgf.updates.prediction_error.exponential
143143

pyhgf/updates/posterior/continuous/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33

44
__all__ = [
55
"continuous_node_posterior_update_ehgf",
6-
"continuous_node_posterior_update_unbounded",
76
"continuous_node_posterior_update",
87
]

pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node.py

Lines changed: 94 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import jax.numpy as jnp
77
from jax import grad, jit
8+
from jax.lax import cond
9+
from jax.tree_util import Partial
810

911
from pyhgf.typing import Edges
1012

@@ -31,7 +33,7 @@ def posterior_update_precision_continuous_node(
3133
3234
Where :math:`\kappa_j` is the volatility coupling strength between the child node
3335
and the state node and :math:`\delta_j^{(k)}` is the value prediction error that
34-
was computed before hand by
36+
was computed beforehand by
3537
:py:func:`pyhgf.updates.prediction_errors.continuous.continuous_node_value_prediction_error`.
3638
3739
For non-linear value coupling:
@@ -80,8 +82,9 @@ def posterior_update_precision_continuous_node(
8082
The attributes of the probabilistic nodes.
8183
edges :
8284
The edges of the probabilistic nodes as a tuple of
83-
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
84-
For each node, the index list value and volatility parents and children.
85+
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
86+
of nodes. For each node, the index lists the value and volatility parents and
87+
children.
8588
node_idx :
8689
Pointer to the value parent node that will be updated.
8790
time_step :
@@ -108,6 +111,60 @@ def posterior_update_precision_continuous_node(
108111
Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
109112
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
110113
114+
"""
115+
# ----------------------------------------------------------------------------------
116+
# Decide which update to use depending on the presence of observed value in the
117+
# children nodes. If no values were observed, the precision should increase
118+
# as a function of time using the function precision_missing_values(). Otherwise,
119+
# we use regular HGF updates for value and volatility couplings.
120+
# ----------------------------------------------------------------------------------
121+
122+
# For all children, get the `observed` flag - if all these values are 0.0, the node
123+
# has not received any observations and we should call precision_missing_values()
124+
observations = []
125+
if edges[node_idx].value_children is not None:
126+
for children_idx in edges[node_idx].value_children: # type: ignore
127+
observations.append(attributes[children_idx]["observed"])
128+
if edges[node_idx].volatility_children is not None:
129+
for children_idx in edges[node_idx].volatility_children: # type: ignore
130+
observations.append(attributes[children_idx]["observed"])
131+
observations = jnp.any(jnp.array(observations))
132+
133+
posterior_precision = cond(
134+
observations,
135+
Partial(precision_update, edges=edges, node_idx=node_idx),
136+
Partial(precision_update_missing_values, edges=edges, node_idx=node_idx),
137+
attributes,
138+
)
139+
140+
return posterior_precision
141+
142+
143+
@partial(jit, static_argnames=("edges", "node_idx"))
144+
def precision_update(attributes: Dict, edges: Edges, node_idx: int) -> float:
145+
"""Compute new precision in the case of observed values.
146+
147+
Parameters
148+
----------
149+
attributes :
150+
The attributes of the probabilistic nodes.
151+
edges :
152+
The edges of the probabilistic nodes as a tuple of
153+
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
154+
of nodes. For each node, the index lists the value and volatility parents and
155+
children.
156+
node_idx :
157+
Pointer to the value parent node that will be updated.
158+
time_step :
159+
The time elapsed between this observation and the previous one.
160+
161+
Returns
162+
-------
163+
posterior_precision :
164+
The new posterior precision when at least one of the children has
165+
observed a new value. We then use the regular HGF update for volatility
166+
coupling.
167+
111168
"""
112169
# sum the prediction errors from both value and volatility coupling
113170
precision_weigthed_prediction_error = 0.0
@@ -177,13 +234,41 @@ def posterior_update_precision_continuous_node(
177234
)
178235

179236
# ensure the new precision is greater than 0
180-
observed_posterior_precision = jnp.where(
237+
posterior_precision = jnp.where(
181238
posterior_precision > 1e-128, posterior_precision, jnp.nan
182239
)
183240

184-
# additionnal steps for unobserved values
185-
# ---------------------------------------
241+
return posterior_precision
186242

243+
244+
@partial(jit, static_argnames=("edges", "node_idx"))
245+
def precision_update_missing_values(
246+
attributes: Dict, edges: Edges, node_idx: int
247+
) -> float:
248+
"""Compute new precision in the case of missing observations.
249+
250+
Parameters
251+
----------
252+
attributes :
253+
The attributes of the probabilistic nodes.
254+
edges :
255+
The edges of the probabilistic nodes as a tuple of
256+
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
257+
of nodes. For each node, the index lists the value and volatility parents and
258+
children.
259+
node_idx :
260+
Pointer to the value parent node that will be updated.
261+
time_step :
262+
The time elapsed between this observation and the previous one.
263+
264+
Returns
265+
-------
266+
posterior_precision_missing_values :
267+
The new posterior precision in the case of missing values in all child nodes.
268+
The new precision decreases proportionally to the time elapsed, accounting for
269+
the influence of volatility parents.
270+
271+
"""
187272
# List the node's volatility parents
188273
volatility_parents_idxs = edges[node_idx].volatility_parents
189274

@@ -201,29 +286,13 @@ def posterior_update_precision_continuous_node(
201286
volatility_coupling * attributes[volatility_parents_idx]["mean"]
202287
)
203288

204-
# compute the predicted_volatility from the total volatility
289+
# compute the new predicted_volatility from the total volatility
205290
time_step = attributes[-1]["time_step"]
206291
predicted_volatility = time_step * jnp.exp(total_volatility)
207292

208293
# Estimate the new precision for the continuous state node
209-
unobserved_posterior_precision = 1 / (
294+
posterior_precision_missing_values = 1 / (
210295
(1 / attributes[node_idx]["precision"]) + predicted_volatility
211296
)
212297

213-
# for all children, look at the values of VAPE
214-
# if all these values are NaNs, the node has not received observations
215-
observations = []
216-
if edges[node_idx].value_children is not None:
217-
for children_idx in edges[node_idx].value_children: # type: ignore
218-
observations.append(attributes[children_idx]["observed"])
219-
if edges[node_idx].volatility_children is not None:
220-
for children_idx in edges[node_idx].volatility_children: # type: ignore
221-
observations.append(attributes[children_idx]["observed"])
222-
observations = jnp.any(jnp.array(observations))
223-
224-
posterior_precision = (
225-
unobserved_posterior_precision * (1 - observations) # type: ignore
226-
+ observed_posterior_precision * observations
227-
)
228-
229-
return posterior_precision
298+
return posterior_precision_missing_values

tests/test_updates/posterior/continuous.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
22

3+
import jax.numpy as jnp
4+
35
from pyhgf.model import Network
46
from pyhgf.updates.posterior.continuous import (
57
continuous_node_posterior_update,
68
continuous_node_posterior_update_ehgf,
7-
continuous_node_posterior_update_unbounded,
89
)
910

1011

@@ -20,17 +21,56 @@ def test_continuous_posterior_updates():
2021

2122
# Standard HGF updates -------------------------------------------------------------
2223
# ----------------------------------------------------------------------------------
24+
25+
# value update
26+
attributes, edges, _ = network.get_network()
27+
attributes[0]["temp"]["value_prediction_error"] = 1.0357
28+
attributes[0]["mean"] = 1.0357
29+
30+
new_attributes = continuous_node_posterior_update(
31+
attributes=attributes, node_idx=1, edges=edges
32+
)
33+
assert jnp.isclose(new_attributes[1]["mean"], 0.51785)
34+
35+
# volatility update
2336
attributes, edges, _ = network.get_network()
24-
_ = continuous_node_posterior_update(attributes=attributes, node_idx=2, edges=edges)
37+
attributes[1]["temp"]["effective_precision"] = 0.01798621006309986
38+
attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467
39+
attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894
40+
attributes[1]["expected_precision"] = 0.9820137619972229
41+
attributes[1]["mean"] = 0.5225493907928467
42+
attributes[1]["precision"] = 1.9820137023925781
43+
44+
new_attributes = continuous_node_posterior_update(
45+
attributes=attributes, node_idx=2, edges=edges
46+
)
47+
assert jnp.isclose(new_attributes[1]["mean"], -0.0021212)
48+
assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)
2549

2650
# eHGF updates ---------------------------------------------------------------------
2751
# ----------------------------------------------------------------------------------
28-
_ = continuous_node_posterior_update_ehgf(
52+
53+
# value update
54+
attributes, edges, _ = network.get_network()
55+
attributes[0]["temp"]["value_prediction_error"] = 1.0357
56+
attributes[0]["mean"] = 1.0357
57+
58+
new_attributes = continuous_node_posterior_update_ehgf(
2959
attributes=attributes, node_idx=2, edges=edges
3060
)
61+
assert jnp.isclose(new_attributes[1]["mean"], 0.51785)
3162

32-
# unbounded updates ----------------------------------------------------------------
33-
# ----------------------------------------------------------------------------------
34-
_ = continuous_node_posterior_update_unbounded(
63+
# volatility update
64+
attributes, edges, _ = network.get_network()
65+
attributes[1]["temp"]["effective_precision"] = 0.01798621006309986
66+
attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467
67+
attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894
68+
attributes[1]["expected_precision"] = 0.9820137619972229
69+
attributes[1]["mean"] = 0.5225493907928467
70+
attributes[1]["precision"] = 1.9820137023925781
71+
72+
new_attributes = continuous_node_posterior_update_ehgf(
3573
attributes=attributes, node_idx=2, edges=edges
3674
)
75+
assert jnp.isclose(new_attributes[1]["mean"], -0.00212589)
76+
assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)

0 commit comments

Comments
 (0)