Skip to content

Commit 94720f2

Browse files
committed
Refactor nested processing of estimates
1 parent 8de49ac commit 94720f2

File tree

1 file changed

+64
-38
lines changed

1 file changed

+64
-38
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,51 +27,77 @@ def estimate(
2727
if not self.built:
2828
raise AssertionError("PointApproximator needs to be built before predicting with it.")
2929

30+
# Prepare the input conditions.
31+
conditions = self._prepare_conditions(conditions, **kwargs)
32+
# Run the internal estimation and convert the output to numpy.
33+
estimates = self._run_inference(conditions, **kwargs)
34+
# Postprocess the inference output with the inverse adapter.
35+
estimates = self._apply_inverse_adapter(estimates, **kwargs)
36+
# Optionally split the arrays along the last axis.
37+
if split:
38+
estimates = split_arrays(estimates, axis=-1)
39+
# Reorder the nested dictionary so that original variable names are at the top.
40+
estimates = self._reorder_estimates(estimates)
41+
# Remove unnecessary nesting.
42+
estimates = self._squeeze_estimates(estimates)
43+
44+
return estimates
45+
46+
def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
47+
"""Adapts and converts the conditions to tensors."""
3048
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
31-
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
32-
conditions = {"inference_variables": self._estimate(**conditions, **kwargs)}
33-
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
34-
conditions = {
35-
outer_key: {
36-
inner_key: self.adapter(
37-
dict(inference_variables=conditions["inference_variables"][outer_key][inner_key]),
49+
return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
50+
51+
def _run_inference(self, conditions: dict[str, Tensor], **kwargs) -> dict[str, dict[str, np.ndarray]]:
52+
"""Runs the internal _estimate function and converts the result to numpy arrays."""
53+
# Run the estimation.
54+
inference_output = self._estimate(**conditions, **kwargs)
55+
# Wrap the result in a dict and convert to numpy.
56+
wrapped_output = {"inference_variables": inference_output}
57+
return keras.tree.map_structure(keras.ops.convert_to_numpy, wrapped_output)
58+
59+
def _apply_inverse_adapter(
60+
self, estimates: dict[str, dict[str, np.ndarray]], **kwargs
61+
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
62+
"""Applies the inverse adapter on each inner element of the inference outputs."""
63+
processed = {}
64+
for score_key, score_val in estimates["inference_variables"].items():
65+
processed[score_key] = {}
66+
for head_key, estimate in score_val.items():
67+
adapted = self.adapter(
68+
{"inference_variables": estimate},
3869
inverse=True,
3970
strict=False,
4071
**kwargs,
4172
)
42-
for inner_key in conditions["inference_variables"][outer_key].keys()
43-
}
44-
for outer_key in conditions["inference_variables"].keys()
45-
}
73+
processed[score_key][head_key] = adapted
74+
return processed
4675

47-
if split:
48-
conditions = split_arrays(conditions, axis=-1)
49-
50-
# get original variable names to reorder them to highest level
51-
inference_variable_names = next(iter(next(iter(conditions.values())).values())).keys()
52-
53-
# change ordering of nested dictionary
54-
conditions = {
55-
variable_name: {
56-
outer_key: {
57-
inner_key: conditions[outer_key][inner_key][variable_name]
58-
for inner_key in conditions[outer_key].keys()
59-
}
60-
for outer_key in conditions.keys()
61-
}
62-
for variable_name in inference_variable_names
63-
}
64-
65-
# remove unnecessary nesting
66-
conditions = {
67-
variable_name: {
68-
outer_key: squeeze_inner_estimates_dict(conditions[variable_name][outer_key])
69-
for outer_key in conditions[variable_name].keys()
70-
}
71-
for variable_name in conditions.keys()
72-
}
76+
def _reorder_estimates(
77+
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
78+
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
79+
"""Reorders the nested dictionary so that the inference variable names become the top-level keys."""
80+
# Grab the variable names from one sample inner dictionary.
81+
sample_inner = next(iter(next(iter(estimates.values())).values()))
82+
variable_names = sample_inner.keys()
83+
reordered = {}
84+
for variable in variable_names:
85+
reordered[variable] = {}
86+
for score_key, inner_dict in estimates.items():
87+
reordered[variable][score_key] = {inner_key: value[variable] for inner_key, value in inner_dict.items()}
88+
return reordered
7389

74-
return conditions
90+
def _squeeze_estimates(
91+
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
92+
) -> dict[str, dict[str, np.ndarray]]:
93+
"""Squeezes each inner estimate dictionary to remove unnecessary nesting."""
94+
squeezed = {}
95+
for variable, variable_estimates in estimates.items():
96+
squeezed[variable] = {
97+
score_key: squeeze_inner_estimates_dict(inner_estimate)
98+
for score_key, inner_estimate in variable_estimates.items()
99+
}
100+
return squeezed
75101

76102
def _estimate(
77103
self,

0 commit comments

Comments
 (0)