55
66import jax .numpy as jnp
77from jax import grad , jit
8+ from jax .lax import cond
9+ from jax .tree_util import Partial
810
911from pyhgf .typing import Edges
1012
@@ -31,7 +33,7 @@ def posterior_update_precision_continuous_node(
3133
3234 Where :math:`\kappa_j` is the volatility coupling strength between the child node
3335 and the state node and :math:`\delta_j^{(k)}` is the value prediction error that
34- was computed before hand by
36+ was computed beforehand by
3537 :py:func:`pyhgf.updates.prediction_errors.continuous.continuous_node_value_prediction_error`.
3638
3739 For non-linear value coupling:
@@ -80,8 +82,9 @@ def posterior_update_precision_continuous_node(
8082 The attributes of the probabilistic nodes.
8183 edges :
8284 The edges of the probabilistic nodes as a tuple of
83- :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
84- For each node, the index list value and volatility parents and children.
85+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
86+ of nodes. For each node, the index lists the value and volatility parents and
87+ children.
8588 node_idx :
8689 Pointer to the value parent node that will be updated.
8790 time_step :
@@ -108,6 +111,60 @@ def posterior_update_precision_continuous_node(
108111 Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
109112 arXiv. https://doi.org/10.48550/ARXIV.2305.10937
110113
114+ """
115+ # ----------------------------------------------------------------------------------
116+ # Decide which update to use depending on the presence of observed value in the
117+ # children nodes. If no values were observed, the precision should increase
118+ # as a function of time using the function precision_missing_values(). Otherwise,
119+ # we use regular HGF updates for value and volatility couplings.
120+ # ----------------------------------------------------------------------------------
121+
122+ # For all children, get the `observed` flag - if all these values are 0.0, the node
123+ # has not received any observations and we should call precision_missing_values()
124+ observations = []
125+ if edges [node_idx ].value_children is not None :
126+ for children_idx in edges [node_idx ].value_children : # type: ignore
127+ observations .append (attributes [children_idx ]["observed" ])
128+ if edges [node_idx ].volatility_children is not None :
129+ for children_idx in edges [node_idx ].volatility_children : # type: ignore
130+ observations .append (attributes [children_idx ]["observed" ])
131+ observations = jnp .any (jnp .array (observations ))
132+
133+ posterior_precision = cond (
134+ observations ,
135+ Partial (precision_update , edges = edges , node_idx = node_idx ),
136+ Partial (precision_update_missing_values , edges = edges , node_idx = node_idx ),
137+ attributes ,
138+ )
139+
140+ return posterior_precision
141+
142+
143+ @partial (jit , static_argnames = ("edges" , "node_idx" ))
144+ def precision_update (attributes : Dict , edges : Edges , node_idx : int ) -> float :
145+ """Compute new precision in the case of observed values.
146+
147+ Parameters
148+ ----------
149+ attributes :
150+ The attributes of the probabilistic nodes.
151+ edges :
152+ The edges of the probabilistic nodes as a tuple of
153+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
154+ of nodes. For each node, the index lists the value and volatility parents and
155+ children.
156+ node_idx :
157+ Pointer to the value parent node that will be updated.
158+ time_step :
159+ The time elapsed between this observation and the previous one.
160+
161+ Returns
162+ -------
163+ posterior_precision :
164+ The new posterior precision when at least one of the children has
165+ observed a new value. We then use the regular HGF update for volatility
166+ coupling.
167+
111168 """
112169 # sum the prediction errors from both value and volatility coupling
113170 precision_weigthed_prediction_error = 0.0
@@ -177,13 +234,41 @@ def posterior_update_precision_continuous_node(
177234 )
178235
179236 # ensure the new precision is greater than 0
180- observed_posterior_precision = jnp .where (
237+ posterior_precision = jnp .where (
181238 posterior_precision > 1e-128 , posterior_precision , jnp .nan
182239 )
183240
184- # additionnal steps for unobserved values
185- # ---------------------------------------
241+ return posterior_precision
186242
243+
244+ @partial (jit , static_argnames = ("edges" , "node_idx" ))
245+ def precision_update_missing_values (
246+ attributes : Dict , edges : Edges , node_idx : int
247+ ) -> float :
248+ """Compute new precision in the case of missing observations.
249+
250+ Parameters
251+ ----------
252+ attributes :
253+ The attributes of the probabilistic nodes.
254+ edges :
255+ The edges of the probabilistic nodes as a tuple of
256+ :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
257+ of nodes. For each node, the index lists the value and volatility parents and
258+ children.
259+ node_idx :
260+ Pointer to the value parent node that will be updated.
261+ time_step :
262+ The time elapsed between this observation and the previous one.
263+
264+ Returns
265+ -------
266+ posterior_precision_missing_values :
267+ The new posterior precision in the case of missing values in all child nodes.
268+ The new precision decreases proportionally to the time elapsed, accounting for
269+ the influence of volatility parents.
270+
271+ """
187272 # List the node's volatility parents
188273 volatility_parents_idxs = edges [node_idx ].volatility_parents
189274
@@ -201,29 +286,13 @@ def posterior_update_precision_continuous_node(
201286 volatility_coupling * attributes [volatility_parents_idx ]["mean" ]
202287 )
203288
204- # compute the predicted_volatility from the total volatility
289+ # compute the new predicted_volatility from the total volatility
205290 time_step = attributes [- 1 ]["time_step" ]
206291 predicted_volatility = time_step * jnp .exp (total_volatility )
207292
208293 # Estimate the new precision for the continuous state node
209- unobserved_posterior_precision = 1 / (
294+ posterior_precision_missing_values = 1 / (
210295 (1 / attributes [node_idx ]["precision" ]) + predicted_volatility
211296 )
212297
213- # for all children, look at the values of VAPE
214- # if all these values are NaNs, the node has not received observations
215- observations = []
216- if edges [node_idx ].value_children is not None :
217- for children_idx in edges [node_idx ].value_children : # type: ignore
218- observations .append (attributes [children_idx ]["observed" ])
219- if edges [node_idx ].volatility_children is not None :
220- for children_idx in edges [node_idx ].volatility_children : # type: ignore
221- observations .append (attributes [children_idx ]["observed" ])
222- observations = jnp .any (jnp .array (observations ))
223-
224- posterior_precision = (
225- unobserved_posterior_precision * (1 - observations ) # type: ignore
226- + observed_posterior_precision * observations
227- )
228-
229- return posterior_precision
298+ return posterior_precision_missing_values
0 commit comments