@@ -23,17 +23,17 @@ def estimate(
2323 conditions : dict [str , np .ndarray ],
2424 split : bool = False ,
2525 ** kwargs ,
26- ) -> dict [str , dict [str , np .ndarray ]]:
26+ ) -> dict [str , dict [str , np .ndarray | dict [ str , np . ndarray ] ]]:
2727 """
28- Provides point estimates based on provided conditions (e.g., observables) .
28+ Estimates point summaries of inference variables based on specified conditions.
2929
3030 This method processes input conditions, computes estimates, applies necessary adapter transformations,
3131 and optionally splits the resulting arrays along the last axis.
3232
3333 Parameters
3434 ----------
3535 conditions : dict[str, np.ndarray]
36- A dictionary mapping variable names to NumPy arrays representing the conditions
36+ A dictionary mapping variable names to arrays representing the conditions
3737 for the estimation process.
3838 split : bool, optional
3939 If True, the estimated arrays are split along the last axis, by default False.
@@ -42,9 +42,15 @@ def estimate(
4242
4343 Returns
4444 -------
45- dict[str, dict[str, np.ndarray]]
46- A nested dictionary where the top-level keys correspond to original variable names,
47- and values contain dictionaries mapping estimation results to NumPy arrays.
45+ estimates : dict[str, dict[str, np.ndarray or dict[str, np.ndarray]]]
46+ The estimates of inference variables in a nested dictionary.
47+
48+ 1. Each first-level key is the name of an inference variable.
49+ 2. Each second-level key is the name of a scoring rule.
50+ 3. (If the scoring rule comprises multiple estimators, each third-level key is the name of an estimator.)
51+
52+ Each estimator output (i.e., dictionary value that is not itself a dictionary) is an array
53+ of shape (num_datasets, point_estimate_size, variable_block_size).
4854 """
4955
5056 conditions = self ._prepare_conditions (conditions , ** kwargs )
@@ -67,39 +73,43 @@ def sample(
6773 conditions : dict [str , np .ndarray ],
6874 split : bool = False ,
6975 ** kwargs ,
70- ) -> dict [str , np .ndarray ]:
76+ ) -> dict [str , dict [ str , np .ndarray ] ]:
7177 """
72- Generate samples from point estimates based on provided conditions. These samples
73- will generally not correspond to samples from the fully Bayesian posterior, since
74- they will assume some parametric form (e.g., Gaussian in the case of mean score).
78+ Draws samples from a parametric distribution based on point estimates for given input conditions.
7579
76- This method draws a specified number of samples according to the given conditions,
77- applies necessary transformations, and optionally splits the resulting arrays along the last axis .
80+ These samples will generally not correspond to samples from the fully Bayesian posterior, since
81+ they will assume some parametric form (e.g., multivariate normal when using the MultivariateNormalScore) .
7882
7983 Parameters
8084 ----------
8185 num_samples : int
8286 The number of samples to generate.
8387 conditions : dict[str, np.ndarray]
84- A dictionary mapping variable names to NumPy arrays representing the conditions
88+ A dictionary mapping variable names to arrays representing the conditions
8589 for the sampling process.
8690 split : bool, optional
8791 If True, the sampled arrays are split along the last axis, by default False.
92+ Currently not supported for `PointApproximator`.
8893 **kwargs
8994 Additional keyword arguments passed to underlying processing functions.
9095
9196 Returns
9297 -------
93- dict[str, np.ndarray]
94- A dictionary where keys correspond to variable names and values are NumPy arrays
95- containing the generated samples.
96- """
98+ samples : dict[str, np.ndarray or dict[str, np.ndarray]]
99+ Samples for all inference variables and all parametric scoring rules in a nested dictionary.
100+
101+ 1. Each first-level key is the name of an inference variable.
102+ 2. (If there are multiple parametric scores, each second-level key is the name of such a score.)
97103
104+ Each output (i.e., dictionary value that is not itself a dictionary) is an array
105+ of shape (num_datasets, num_samples, variable_block_size).
106+ """
98107 conditions = self ._prepare_conditions (conditions , ** kwargs )
99108 samples = self ._sample (num_samples , ** conditions , ** kwargs )
100109 samples = self ._apply_inverse_adapter_to_samples (samples , ** kwargs )
101110 # Optionally split the arrays along the last axis.
102111 if split :
112+ raise NotImplementedError ("split=True is currently not supported for `PointApproximator`." )
103113 samples = split_arrays (samples , axis = - 1 )
104114 # Squeeze samples if there's only one key-value pair.
105115 samples = self ._squeeze_samples (samples )
0 commit comments