Skip to content

Commit 50167b3

Browse files
committed
Sample from parametric scoring rules; more refactoring in PointApproximator
1 parent bff8d20 commit 50167b3

File tree

4 files changed

+145
-22
lines changed

4 files changed

+145
-22
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,9 @@ def estimate(
2727
if not self.built:
2828
raise AssertionError("PointApproximator needs to be built before predicting with it.")
2929

30-
# Prepare the input conditions.
3130
conditions = self._prepare_conditions(conditions, **kwargs)
32-
# Run the internal estimation and convert the output to numpy.
33-
estimates = self._run_inference(conditions, **kwargs)
34-
# Postprocess the inference output with the inverse adapter.
35-
estimates = self._apply_inverse_adapter(estimates, **kwargs)
31+
estimates = self._estimate(**conditions, **kwargs)
32+
estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
3633
# Optionally split the arrays along the last axis.
3734
if split:
3835
estimates = split_arrays(estimates, axis=-1)
@@ -43,25 +40,40 @@ def estimate(
4340

4441
return estimates
4542

43+
def sample(
44+
self,
45+
*,
46+
num_samples: int,
47+
conditions: dict[str, np.ndarray],
48+
split: bool = False,
49+
**kwargs,
50+
) -> dict[str, np.ndarray]:
51+
if not self.built:
52+
raise AssertionError("This model needs to be built before using it for sampling.")
53+
54+
conditions = self._prepare_conditions(conditions, **kwargs)
55+
samples = self._sample(num_samples, **conditions, **kwargs)
56+
samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
57+
# Optionally split the arrays along the last axis.
58+
if split:
59+
samples = split_arrays(samples, axis=-1)
60+
# Squeeze samples if there's only one key-value pair.
61+
samples = self._squeeze_samples(samples)
62+
63+
return samples
64+
4665
def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
4766
"""Adapts and converts the conditions to tensors."""
4867
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
4968
return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
5069

51-
def _run_inference(self, conditions: dict[str, Tensor], **kwargs) -> dict[str, dict[str, np.ndarray]]:
52-
"""Runs the internal _estimate function and converts the result to numpy arrays."""
53-
# Run the estimation.
54-
inference_output = self._estimate(**conditions, **kwargs)
55-
# Wrap the result in a dict and convert to numpy.
56-
wrapped_output = {"inference_variables": inference_output}
57-
return keras.tree.map_structure(keras.ops.convert_to_numpy, wrapped_output)
58-
59-
def _apply_inverse_adapter(
60-
self, estimates: dict[str, dict[str, np.ndarray]], **kwargs
70+
def _apply_inverse_adapter_to_estimates(
71+
self, estimates: dict[str, dict[str, Tensor]], **kwargs
6172
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
62-
"""Applies the inverse adapter on each inner element of the inference outputs."""
73+
"""Applies the inverse adapter on each inner element of the _estimate output dictionary."""
74+
estimates = keras.tree.map_structure(keras.ops.convert_to_numpy, estimates)
6375
processed = {}
64-
for score_key, score_val in estimates["inference_variables"].items():
76+
for score_key, score_val in estimates.items():
6577
processed[score_key] = {}
6678
for head_key, estimate in score_val.items():
6779
adapted = self.adapter(
@@ -73,6 +85,21 @@ def _apply_inverse_adapter(
7385
processed[score_key][head_key] = adapted
7486
return processed
7587

88+
def _apply_inverse_adapter_to_samples(
89+
self, samples: dict[str, Tensor], **kwargs
90+
) -> dict[str, dict[str, np.ndarray]]:
91+
"""Applies the inverse adapter to a dictionary of samples."""
92+
samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
93+
processed = {}
94+
for score_key, samples in samples.items():
95+
processed[score_key] = self.adapter(
96+
{"inference_variables": samples},
97+
inverse=True,
98+
strict=False,
99+
**kwargs,
100+
)
101+
return processed
102+
76103
def _reorder_estimates(
77104
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
78105
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
@@ -99,6 +126,12 @@ def _squeeze_estimates(
99126
}
100127
return squeezed
101128

129+
def _squeeze_samples(self, samples: dict[str, np.ndarray]) -> np.ndarray or dict[str, np.ndarray]:
130+
"""Squeezes the samples dictionary to just the value if there is only one key-value pair."""
131+
if len(samples) == 1:
132+
return next(iter(samples.values())) # Extract and return the only item's value
133+
return samples
134+
102135
def _estimate(
103136
self,
104137
inference_conditions: Tensor = None,

bayesflow/networks/point_inference_network.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,29 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
165165

166166
# WIP: untested draft of sample method
167167
@allow_batch_size
168-
def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> dict[str, Tensor]:
169-
output = self.subnet(conditions)
168+
def sample(self, batch_shape: Shape, conditions: Tensor = None) -> dict[str, Tensor]:
169+
"""
170+
Parameters
171+
----------
172+
batch_shape : tuple,
173+
Expected dimensions depend on `conditions`
174+
- conditional sampling: (batch_size, num_samples) if `conditions` is a tensor
175+
of shape (batch_size, num_samples)
176+
- unconditional sampling: (num_samples,) if `conditions` is None
177+
conditions : Tensor or None, default None
178+
Optional inference conditions. If `conditions` is not given, the method will return unconditional samples.
179+
180+
Returns
181+
-------
182+
samples : dict[str, Tensor]
183+
Samples for every parametric scoring rule. Dict values have shape (batch_size, num_samples, num_variables)
184+
or (num_samples, num_variables) for conditional or unconditional sampling respectively.
185+
"""
186+
if conditions is None: # unconditional estimation uses a fixed input vector
187+
conditions = keras.ops.ones(batch_shape, dtype="float32").reshape(1, -1, 1)
188+
189+
# conditions are duplicated along axis 1 num_sample times
190+
output = self.subnet(conditions[:, 0, :])
170191
samples = {}
171192

172193
for score_key, score in self.scores.items():

tests/test_approximators/conftest.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ def batch_size():
66
return 8
77

88

9+
@pytest.fixture()
10+
def num_samples():
11+
return 100
12+
13+
914
@pytest.fixture()
1015
def summary_network():
1116
return None
@@ -32,18 +37,32 @@ def continuous_approximator(adapter, inference_network, summary_network):
3237
@pytest.fixture()
3338
def point_inference_network():
3439
from bayesflow.networks import PointInferenceNetwork
35-
from bayesflow.scores import NormedDifferenceScore, QuantileScore
40+
from bayesflow.scores import NormedDifferenceScore, QuantileScore, MultivariateNormalScore
3641

3742
return PointInferenceNetwork(
3843
scores=dict(
3944
mean=NormedDifferenceScore(k=2),
4045
quantiles=QuantileScore(q=[0.1, 0.5, 0.9]),
46+
mvn=MultivariateNormalScore(),
4147
),
4248
subnet="mlp",
4349
subnet_kwargs=dict(widths=(32, 32)),
4450
)
4551

4652

53+
@pytest.fixture()
54+
def point_inference_network_with_multiple_parametric_scores():
55+
from bayesflow.networks import PointInferenceNetwork
56+
from bayesflow.scores import MultivariateNormalScore
57+
58+
return PointInferenceNetwork(
59+
scores=dict(
60+
mvn1=MultivariateNormalScore(),
61+
mvn2=MultivariateNormalScore(),
62+
),
63+
)
64+
65+
4766
@pytest.fixture()
4867
def point_approximator(adapter, point_inference_network, summary_network):
4968
from bayesflow import PointApproximator
@@ -55,8 +74,23 @@ def point_approximator(adapter, point_inference_network, summary_network):
5574
)
5675

5776

58-
# @pytest.fixture(params=["continuous_approximator"], scope="function")
59-
@pytest.fixture(params=["continuous_approximator", "point_approximator"], scope="function")
77+
@pytest.fixture()
78+
def point_approximator_with_multiple_parametric_scores(
79+
adapter, point_inference_network_with_multiple_parametric_scores, summary_network
80+
):
81+
from bayesflow import PointApproximator
82+
83+
return PointApproximator(
84+
adapter=adapter,
85+
inference_network=point_inference_network_with_multiple_parametric_scores,
86+
summary_network=summary_network,
87+
)
88+
89+
90+
@pytest.fixture(
91+
params=["continuous_approximator", "point_approximator", "point_approximator_with_multiple_parametric_scores"],
92+
scope="function",
93+
)
6094
def approximator(request):
6195
return request.getfixturevalue(request.param)
6296

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import keras
2+
import numpy as np
3+
from bayesflow.scores import ParametricDistributionScore
4+
5+
6+
def test_approximator_sample(point_approximator, simulator, batch_size, num_samples, adapter):
7+
data = simulator.sample((batch_size,))
8+
9+
batch = adapter(data)
10+
point_approximator.build_from_data(batch)
11+
12+
samples = point_approximator.sample(num_samples=num_samples, conditions=data)
13+
14+
assert isinstance(samples, dict)
15+
16+
print(keras.tree.map_structure(keras.ops.shape, samples))
17+
18+
# Expect doubly nested sample dictionary if more than one samplable score is available.
19+
scores_for_sampling = [
20+
score
21+
for score in point_approximator.inference_network.scores.values()
22+
if isinstance(score, ParametricDistributionScore)
23+
]
24+
25+
if len(scores_for_sampling) > 1:
26+
for score_key, score_samples in samples.items():
27+
for variable, variable_estimates in score_samples.items():
28+
assert isinstance(variable_estimates, np.ndarray)
29+
assert variable_estimates.shape[:-1] == (batch_size, num_samples)
30+
31+
# If only one score is available, the outer nesting should be dropped.
32+
else:
33+
for variable, variable_estimates in samples.items():
34+
assert isinstance(variable_estimates, np.ndarray)
35+
assert variable_estimates.shape[:-1] == (batch_size, num_samples)

0 commit comments

Comments
 (0)