Skip to content

Commit 4251bbd

Browse files
committed
Estimate method in BasicWorkflow and docs for PointApproximator
1 parent 97c381d commit 4251bbd

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@ def estimate(
2323
conditions: dict[str, np.ndarray],
2424
split: bool = False,
2525
**kwargs,
26-
) -> dict[str, dict[str, np.ndarray]]:
26+
) -> dict[str, dict[str, np.ndarray | dict[str, np.ndarray]]]:
2727
"""
28-
Provides point estimates based on provided conditions (e.g., observables).
28+
Estimates point summaries of inference variables based on specified conditions.
2929
3030
This method processes input conditions, computes estimates, applies necessary adapter transformations,
3131
and optionally splits the resulting arrays along the last axis.
3232
3333
Parameters
3434
----------
3535
conditions : dict[str, np.ndarray]
36-
A dictionary mapping variable names to NumPy arrays representing the conditions
36+
A dictionary mapping variable names to arrays representing the conditions
3737
for the estimation process.
3838
split : bool, optional
3939
If True, the estimated arrays are split along the last axis, by default False.
@@ -42,9 +42,15 @@ def estimate(
4242
4343
Returns
4444
-------
45-
dict[str, dict[str, np.ndarray]]
46-
A nested dictionary where the top-level keys correspond to original variable names,
47-
and values contain dictionaries mapping estimation results to NumPy arrays.
45+
estimates : dict[str, dict[str, np.ndarray or dict[str, np.ndarray]]]
46+
The estimates of inference variables in a nested dictionary.
47+
48+
1. Each first-level key is the name of an inference variable.
49+
2. Each second-level key is the name of a scoring rule.
50+
3. (If the scoring rule comprises multiple estimators, each third-level key is the name of an estimator.)
51+
52+
Each estimator output (i.e., dictionary value that is not itself a dictionary) is an array
53+
of shape (num_datasets, point_estimate_size, variable_block_size).
4854
"""
4955

5056
conditions = self._prepare_conditions(conditions, **kwargs)
@@ -67,39 +73,43 @@ def sample(
6773
conditions: dict[str, np.ndarray],
6874
split: bool = False,
6975
**kwargs,
70-
) -> dict[str, np.ndarray]:
76+
) -> dict[str, dict[str, np.ndarray]]:
7177
"""
72-
Generate samples from point estimates based on provided conditions. These samples
73-
will generally not correspond to samples from the fully Bayesian posterior, since
74-
they will assume some parametric form (e.g., Gaussian in the case of mean score).
78+
Draws samples from a parametric distribution based on point estimates for given input conditions.
7579
76-
This method draws a specified number of samples according to the given conditions,
77-
applies necessary transformations, and optionally splits the resulting arrays along the last axis.
80+
These samples will generally not correspond to samples from the fully Bayesian posterior, since
81+
they will assume some parametric form (e.g., multivariate normal when using the MultivariateNormalScore).
7882
7983
Parameters
8084
----------
8185
num_samples : int
8286
The number of samples to generate.
8387
conditions : dict[str, np.ndarray]
84-
A dictionary mapping variable names to NumPy arrays representing the conditions
88+
A dictionary mapping variable names to arrays representing the conditions
8589
for the sampling process.
8690
split : bool, optional
8791
If True, the sampled arrays are split along the last axis, by default False.
92+
Currently not supported for `PointApproximator`.
8893
**kwargs
8994
Additional keyword arguments passed to underlying processing functions.
9095
9196
Returns
9297
-------
93-
dict[str, np.ndarray]
94-
A dictionary where keys correspond to variable names and values are NumPy arrays
95-
containing the generated samples.
96-
"""
98+
samples : dict[str, np.ndarray or dict[str, np.ndarray]]
99+
Samples for all inference variables and all parametric scoring rules in a nested dictionary.
100+
101+
1. Each first-level key is the name of an inference variable.
102+
2. (If there are multiple parametric scores, each second-level key is the name of such a score.)
97103
104+
Each output (i.e., dictionary value that is not itself a dictionary) is an array
105+
of shape (num_datasets, num_samples, variable_block_size).
106+
"""
98107
conditions = self._prepare_conditions(conditions, **kwargs)
99108
samples = self._sample(num_samples, **conditions, **kwargs)
100109
samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
101110
# Optionally split the arrays along the last axis.
102111
if split:
112+
raise NotImplementedError("split=True is currently not supported for `PointApproximator`.")
103113
samples = split_arrays(samples, axis=-1)
104114
# Squeeze samples if there's only one key-value pair.
105115
samples = self._squeeze_samples(samples)

bayesflow/networks/point_inference_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def call(
128128
conditions: Tensor = None,
129129
training: bool = False,
130130
**kwargs,
131-
) -> dict[str, Tensor]:
131+
) -> dict[str, dict[str, Tensor]]:
132132
if xz is None and not self.built:
133133
raise ValueError("Cannot build inference network without inference variables.")
134134
if conditions is None: # unconditional estimation uses a fixed input vector

bayesflow/workflows/basic_workflow.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,36 @@ def sample(
290290
"""
291291
return self.approximator.sample(num_samples=num_samples, conditions=conditions, **kwargs)
292292

293+
def estimate(
294+
self,
295+
*,
296+
conditions: dict[str, np.ndarray],
297+
**kwargs,
298+
) -> dict[str, dict[str, np.ndarray | dict[str, np.ndarray]]]:
299+
"""
300+
Estimates point summaries of inference variables based on specified conditions.
301+
302+
Parameters
303+
----------
304+
conditions : dict[str, np.ndarray]
305+
A dictionary mapping variable names to arrays representing the conditions for the estimation process.
306+
**kwargs
307+
Additional keyword arguments passed to underlying processing functions.
308+
309+
Returns
310+
-------
311+
estimates : dict[str, dict[str, np.ndarray or dict[str, np.ndarray]]]
312+
The estimates of inference variables in a nested dictionary.
313+
314+
1. Each first-level key is the name of an inference variable.
315+
2. Each second-level key is the name of a scoring rule.
316+
3. (If the scoring rule comprises multiple estimators, each third-level key is the name of an estimator.)
317+
318+
Each estimator output (i.e., dictionary value that is not itself a dictionary) is an array
319+
of shape (num_datasets, point_estimate_size, variable_block_size).
320+
"""
321+
return self.approximator.estimate(conditions=conditions, **kwargs)
322+
293323
def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
294324
"""
295325
Compute the log probability of given variables under the approximator.

0 commit comments

Comments
 (0)