55)
66
77from bayesflow .types import Tensor
8- from bayesflow .utils import filter_kwargs , split_arrays , squeeze_inner_estimates_dict
8+ from bayesflow .utils import filter_kwargs , split_arrays , squeeze_inner_estimates_dict , logging
99from .continuous_approximator import ContinuousApproximator
1010
1111
@@ -14,8 +14,9 @@ class PointApproximator(ContinuousApproximator):
1414 """
1515 A workflow for fast amortized point estimation of a conditional distribution.
1616
17- The distribution is approximated by point estimators, parameterized by a feed-forward `PointInferenceNetwork`.
18- Conditions can be compressed by an optional `SummaryNetwork` or used directly as input to the inference network.
17+ The distribution is approximated by point estimators, parameterized by a feed-forward
18+ :py:class:`~bayesflow.networks.PointInferenceNetwork`. Conditions can be compressed by an optional summary network
19+ (inheriting from :py:class:`~bayesflow.networks.SummaryNetwork`) or used directly as input to the inference network.
1920 """
2021
2122 def estimate (
@@ -89,7 +90,7 @@ def sample(
8990 for the sampling process.
9091 split : bool, optional
9192 If True, the sampled arrays are split along the last axis, by default False.
92- Currently not supported for `PointApproximator`.
93+ Currently not supported for :py:class: `PointApproximator` .
9394 **kwargs
9495 Additional keyword arguments passed to underlying processing functions.
9596
@@ -111,14 +112,50 @@ def sample(
111112 if split :
112113 raise NotImplementedError ("split=True is currently not supported for `PointApproximator`." )
113114 samples = split_arrays (samples , axis = - 1 )
114- # Squeeze samples if there's only one key-value pair.
115- samples = self ._squeeze_samples (samples )
115+ # Squeeze sample dictionary if there's only one key-value pair.
116+ samples = self ._squeeze_parametric_score_major_dict (samples )
116117
117118 return samples
118119
120+ def log_prob (
121+ self ,
122+ * ,
123+ data : dict [str , np .ndarray ],
124+ ** kwargs ,
125+ ) -> np .ndarray | dict [str , np .ndarray ]:
126+ """
127+ Computes the log-probability of given data under the parametric distribution(s) for given input conditions.
128+
129+ Parameters
130+ ----------
131+ data : dict[str, np.ndarray]
132+ A dictionary mapping variable names to arrays representing the inference conditions and variables.
133+ **kwargs
134+ Additional keyword arguments passed to underlying processing functions.
135+
136+ Returns
137+ -------
138+ log_prob : np.ndarray or dict[str, np.ndarray]
139+ Log-probabilities of the distribution
140+ `p(inference_variables | inference_conditions, h(summary_conditions))` for all parametric scoring rules.
141+
142+ If only one parametric score is available, output is an array of log-probabilities.
143+
144+ Output is a dictionary if multiple parametric scores are available.
145+ Then, each key is the name of a score and values are corresponding log-probabilities.
146+
147+ Log-probabilities have shape (num_datasets,).
148+ """
149+ log_prob = super ().log_prob (data = data , ** kwargs )
150+ # Squeeze log probabilities dictionary if there's only one key-value pair.
151+ log_prob = self ._squeeze_parametric_score_major_dict (log_prob )
152+
153+ return log_prob
154+
119155 def _prepare_conditions (self , conditions : dict [str , np .ndarray ], ** kwargs ) -> dict [str , Tensor ]:
120156 """Adapts and converts the conditions to tensors."""
121157 conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
158+ conditions .pop ("inference_variables" , None )
122159 return keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
123160
124161 def _apply_inverse_adapter_to_estimates (
@@ -130,6 +167,12 @@ def _apply_inverse_adapter_to_estimates(
130167 for score_key , score_val in estimates .items ():
131168 processed [score_key ] = {}
132169 for head_key , estimate in score_val .items ():
170+ if head_key in self .inference_network .scores [score_key ].NOT_TRANSFORMING_LIKE_VECTOR_WARNING :
171+ logging .warning (
172+ f"Estimate '{ score_key } .{ head_key } ' is marked to not transform like a vector. "
173+ f"It was treated like a vector by the adapter. Handle '{ head_key } ' estimates with care."
174+ )
175+
133176 adapted = self .adapter (
134177 {"inference_variables" : estimate },
135178 inverse = True ,
@@ -180,8 +223,10 @@ def _squeeze_estimates(
180223 }
181224 return squeezed
182225
183- def _squeeze_samples (self , samples : dict [str , np .ndarray ]) -> np .ndarray or dict [str , np .ndarray ]:
184- """Squeezes the samples dictionary to just the value if there is only one key-value pair."""
226+ def _squeeze_parametric_score_major_dict (
227+ self , samples : dict [str , np .ndarray ]
228+ ) -> np .ndarray or dict [str , np .ndarray ]:
229+ """Squeezes the dictionary to just the value if there is only one key-value pair."""
185230 if len (samples ) == 1 :
186231 return next (iter (samples .values ())) # Extract and return the only item's value
187232 return samples
0 commit comments