Flexible implementation of point estimation #281
@@ -0,0 +1,154 @@

```python
import keras
import numpy as np
from keras.saving import (
    register_keras_serializable as serializable,
)

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
from .continuous_approximator import ContinuousApproximator


@serializable(package="bayesflow.approximators")
class PointApproximator(ContinuousApproximator):
    """
    A workflow for fast amortized point estimation of a conditional distribution.

    The distribution is approximated by point estimators, parameterized by a feed-forward `PointInferenceNetwork`.
    Conditions can be compressed by an optional `SummaryNetwork` or used directly as input to the inference network.
    """

    def estimate(
```
Collaborator (Author):

My two cents about nesting complexity. The output structure of `PointApproximator.estimate` should identify the variable names and the point-estimate kinds. To stay as close as possible to the `ContinuousApproximator.sample` output, I am quite happy with a variable-name-major nesting. Replacing tensors with their shapes, this looks like the following; to see it in context, check the notebook.

Having this functionality in the first place requires a few nested computations. Maybe some nice utility functions could help here, but before coming up with those, I wanted to err on the side of not hiding how it works. If it is not readable, I should improve the comments. What do you think?

Contributor:

I am fine with doubly nested outputs as you show them, assuming they are necessary and we cannot efficiently do something simpler. For readability, a loop might be preferred over a dictionary comprehension.

Collaborator (Author):

Triple nesting is necessary for scores that need multiple estimates, for example if you estimate both the mean and the (co)variance.
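To make the variable-name-major nesting concrete, here is an illustrative sketch of the output structure; the variable names, score keys, head names, and shapes below are hypothetical (not taken from the notebook), chosen only to show where the third nesting level comes from:

```python
import numpy as np

# Hypothetical output of PointApproximator.estimate with two inference
# variables and two scoring rules. "quantiles" is a single-head score, so
# its inner dict is squeezed to a bare array; "mean_var" has multiple
# heads (mean and covariance), which forces the third nesting level.
estimates = {
    "theta": {                                   # variable name (top level)
        "quantiles": np.zeros((128, 3, 2)),      # single-head score: bare array
        "mean_var": {                            # multi-head score: inner dict
            "mean": np.zeros((128, 2)),
            "covariance": np.zeros((128, 2, 2)),
        },
    },
    "sigma": {
        "quantiles": np.zeros((128, 3, 1)),
        "mean_var": {
            "mean": np.zeros((128, 1)),
            "covariance": np.zeros((128, 1, 1)),
        },
    },
}
```

So users index first by the variable they care about, then by the score, and only multi-head scores expose a third level.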
```python
        self,
        conditions: dict[str, np.ndarray],
        split: bool = False,
        **kwargs,
    ) -> dict[str, dict[str, np.ndarray]]:
        conditions = self._prepare_conditions(conditions, **kwargs)
        estimates = self._estimate(**conditions, **kwargs)
        estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
        # Optionally split the arrays along the last axis.
        if split:
            estimates = split_arrays(estimates, axis=-1)
        # Reorder the nested dictionary so that original variable names are at the top.
        estimates = self._reorder_estimates(estimates)
        # Remove unnecessary nesting.
        estimates = self._squeeze_estimates(estimates)

        return estimates

    def sample(
        self,
        *,
        num_samples: int,
        conditions: dict[str, np.ndarray],
        split: bool = False,
        **kwargs,
    ) -> dict[str, np.ndarray]:
        conditions = self._prepare_conditions(conditions, **kwargs)
        samples = self._sample(num_samples, **conditions, **kwargs)
        samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
        # Optionally split the arrays along the last axis.
        if split:
            samples = split_arrays(samples, axis=-1)
        # Squeeze samples if there's only one key-value pair.
        samples = self._squeeze_samples(samples)

        return samples

    def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
        """Adapts and converts the conditions to tensors."""
        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
        return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

    def _apply_inverse_adapter_to_estimates(
        self, estimates: dict[str, dict[str, Tensor]], **kwargs
    ) -> dict[str, dict[str, dict[str, np.ndarray]]]:
        """Applies the inverse adapter to each inner element of the _estimate output dictionary."""
        estimates = keras.tree.map_structure(keras.ops.convert_to_numpy, estimates)
        processed = {}
        for score_key, score_val in estimates.items():
            processed[score_key] = {}
            for head_key, estimate in score_val.items():
                adapted = self.adapter(
                    {"inference_variables": estimate},
                    inverse=True,
                    strict=False,
                    **kwargs,
                )
                processed[score_key][head_key] = adapted
        return processed

    def _apply_inverse_adapter_to_samples(
        self, samples: dict[str, Tensor], **kwargs
    ) -> dict[str, dict[str, np.ndarray]]:
        """Applies the inverse adapter to a dictionary of samples."""
        samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
        processed = {}
        # Use a distinct loop variable to avoid shadowing `samples`.
        for score_key, sample in samples.items():
            processed[score_key] = self.adapter(
                {"inference_variables": sample},
                inverse=True,
                strict=False,
                **kwargs,
            )
        return processed

    def _reorder_estimates(
        self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
    ) -> dict[str, dict[str, dict[str, np.ndarray]]]:
        """Reorders the nested dictionary so that the inference variable names become the top-level keys."""
        # Grab the variable names from one sample inner dictionary.
        sample_inner = next(iter(next(iter(estimates.values())).values()))
        variable_names = sample_inner.keys()
        reordered = {}
        for variable in variable_names:
            reordered[variable] = {}
            for score_key, inner_dict in estimates.items():
                reordered[variable][score_key] = {inner_key: value[variable] for inner_key, value in inner_dict.items()}
        return reordered

    def _squeeze_estimates(
        self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
    ) -> dict[str, dict[str, np.ndarray]]:
        """Squeezes each inner estimate dictionary to remove unnecessary nesting."""
        squeezed = {}
        for variable, variable_estimates in estimates.items():
            squeezed[variable] = {
                score_key: squeeze_inner_estimates_dict(inner_estimate)
                for score_key, inner_estimate in variable_estimates.items()
            }
        return squeezed

    def _squeeze_samples(self, samples: dict[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]:
        """Squeezes the samples dictionary to just the value if there is only one key-value pair."""
        if len(samples) == 1:
            return next(iter(samples.values()))  # Extract and return the only item's value.
        return samples

    def _estimate(
        self,
        inference_conditions: Tensor = None,
        summary_variables: Tensor = None,
        **kwargs,
    ) -> dict[str, dict[str, Tensor]]:
        if self.summary_network is None:
            if summary_variables is not None:
                raise ValueError("Cannot use summary variables without a summary network.")
        else:
            if summary_variables is None:
                raise ValueError("Summary variables are required when a summary network is present.")

            summary_outputs = self.summary_network(
                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
            )

            if inference_conditions is None:
                inference_conditions = summary_outputs
            else:
                inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)

        return self.inference_network(
            conditions=inference_conditions,
            **filter_kwargs(kwargs, self.inference_network.call),
        )
```
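The reordering step in `estimate` pivots the `_estimate` output from score-major to variable-major nesting. Here is a standalone sketch of the same pivot on plain dictionaries, mirroring the logic of `_reorder_estimates`; the variable names, score keys, and shapes are made up for illustration:

```python
import numpy as np

def reorder_estimates(estimates):
    """Pivot a {score: {head: {variable: array}}} nest into
    {variable: {score: {head: array}}} (mirrors _reorder_estimates)."""
    # Grab the variable names from one sample inner dictionary.
    sample_inner = next(iter(next(iter(estimates.values())).values()))
    reordered = {}
    for variable in sample_inner:
        reordered[variable] = {
            score_key: {head: value[variable] for head, value in inner.items()}
            for score_key, inner in estimates.items()
        }
    return reordered

# Score-major input: two scores, each with one head, each head holding
# per-variable arrays (as produced after the inverse adapter).
score_major = {
    "mean": {"value": {"theta": np.zeros((4, 2)), "sigma": np.zeros((4, 1))}},
    "quantiles": {"value": {"theta": np.zeros((4, 3, 2)), "sigma": np.zeros((4, 3, 1))}},
}
variable_major = reorder_estimates(score_major)
```

After the pivot, the variable names sit at the top level, matching the shape of the `ContinuousApproximator.sample` output that the thread below discusses.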
Reviewer:

This is very strongly nested, and I currently don't understand the reason for it. Keep in mind that objects you return to the user should be reasonably simple; otherwise users will get confused.

Collaborator (Author):

The `ContinuousApproximator.estimate` method is just a convenience method to ease interoperability with the `PointApproximator.estimate` method. Here, the nesting is irreducible if we want the output to match `PointApproximator.estimate`. I commented on the nesting complexity over there (`PointApproximator.estimate`), since that is where the original estimation takes place, and I believe that is where the real question of whether to refactor the nesting lies.
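On the "keep returned objects simple" point: the single-entry squeeze from the diff is one such simplification, and it can be sketched in isolation (the array shapes and variable names here are made up):

```python
import numpy as np

def squeeze_samples(samples):
    """If the dict holds a single key-value pair, return just the value;
    otherwise return the dict unchanged (mirrors _squeeze_samples)."""
    if len(samples) == 1:
        return next(iter(samples.values()))
    return samples

# With one inference variable, the user gets a bare array back;
# with several, the dict structure is preserved.
single = {"theta": np.zeros((10, 100, 2))}
multi = {"theta": np.zeros((10, 100, 2)), "sigma": np.zeros((10, 100, 1))}
```

This keeps the common single-variable case flat while still identifying variables by name whenever there is more than one.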