From 7b2827a2242e5df326207995b9e1a7f43c6f9f7a Mon Sep 17 00:00:00 2001
From: han-ol
Date: Wed, 12 Feb 2025 23:14:55 +0100
Subject: [PATCH 1/4] Add sample weight support to change loss aggregation

---
 bayesflow/approximators/continuous_approximator.py  | 11 +++++++++--
 bayesflow/networks/coupling_flow/coupling_flow.py   |  8 +++++---
 bayesflow/networks/flow_matching/flow_matching.py   | 10 +++++++---
 bayesflow/networks/free_form_flow/free_form_flow.py | 10 +++++++---
 bayesflow/networks/inference_network.py             | 11 ++++++++++-
 5 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index 5c389b329..16a7676cf 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -41,6 +41,7 @@ def build_adapter(
         inference_variables: Sequence[str],
         inference_conditions: Sequence[str] = None,
         summary_variables: Sequence[str] = None,
+        sample_weights: Sequence[str] = None,
     ) -> Adapter:
         adapter = Adapter.create_default(inference_variables)

@@ -50,7 +51,12 @@ def build_adapter(
         if summary_variables is not None:
             adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")

-        adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
+        if sample_weights is not None:  # we could provide automatic multiplication of different sample weights
+            adapter = adapter.concatenate(sample_weights, into="sample_weights")
+
+        adapter = adapter.keep(
+            ["inference_variables", "inference_conditions", "summary_variables", "sample_weights"]
+        ).standardize(exclude="sample_weights")

         return adapter

@@ -77,6 +83,7 @@ def compute_metrics(
         inference_variables: Tensor,
         inference_conditions: Tensor = None,
         summary_variables: Tensor = None,
+        sample_weights: Tensor = None,
         stage: str = "training",
     ) -> dict[str, Tensor]:
         if self.summary_network is None:
@@ -98,7 +105,7 @@ def compute_metrics(
             inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

         inference_metrics = self.inference_network.compute_metrics(
-            inference_variables, conditions=inference_conditions, stage=stage
+            inference_variables, conditions=inference_conditions, sample_weights=sample_weights, stage=stage
         )

         loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py
index a357d52d8..cd2f2dece 100644
--- a/bayesflow/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/networks/coupling_flow/coupling_flow.py
@@ -117,10 +117,12 @@ def _inverse(
         return x

-    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
-        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
+    def compute_metrics(
+        self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
+    ) -> dict[str, Tensor]:
+        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)
         z, log_density = self(x, conditions=conditions, inverse=False, density=True)
-        loss = -keras.ops.mean(log_density)
+        loss = self.aggregate(-log_density, sample_weights)

         return base_metrics | {"loss": loss}
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py
index 07731edbe..d4798b0c3 100644
--- a/bayesflow/networks/flow_matching/flow_matching.py
+++ b/bayesflow/networks/flow_matching/flow_matching.py
@@ -183,7 +183,11 @@ def deltas(t, xz):
         return x

     def compute_metrics(
-        self, x: Tensor | Sequence[Tensor, ...], conditions: Tensor = None, stage: str = "training"
+        self,
+        x: Tensor | Sequence[Tensor, ...],
+        conditions: Tensor = None,
+        sample_weights: Tensor = None,
+        stage: str = "training",
     ) -> dict[str, Tensor]:
         if isinstance(x, Sequence):
             # already pre-configured
@@ -208,11 +212,11 @@ def compute_metrics(
         x = t * x1 + (1 - t) * x0
         target_velocity = x1 - x0

-        base_metrics = super().compute_metrics(x1, conditions, stage)
+        base_metrics = super().compute_metrics(x1, conditions, sample_weights, stage)
         predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

         loss = self.loss_fn(target_velocity, predicted_velocity)
-        loss = keras.ops.mean(loss)
+        loss = self.aggregate(loss, sample_weights)

         return base_metrics | {"loss": loss}
diff --git a/bayesflow/networks/free_form_flow/free_form_flow.py b/bayesflow/networks/free_form_flow/free_form_flow.py
index fd5ca180a..27e2826fd 100644
--- a/bayesflow/networks/free_form_flow/free_form_flow.py
+++ b/bayesflow/networks/free_form_flow/free_form_flow.py
@@ -182,8 +182,10 @@ def _sample_v(self, x):
             raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
         return v

-    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
-        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
+    def compute_metrics(
+        self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
+    ) -> dict[str, Tensor]:
+        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)

         # sample random vector
         v = self._sample_v(x)
@@ -204,6 +206,8 @@ def decode(z):
         nll = -self.base_distribution.log_prob(z)
         maximum_likelihood_loss = nll - surrogate
         reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
-        loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)
+
+        losses = maximum_likelihood_loss + self.beta * reconstruction_loss
+        loss = self.aggregate(losses, sample_weights)

         return base_metrics | {"loss": loss}
diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py
index 868be5582..8991d51ec 100644
--- a/bayesflow/networks/inference_network.py
+++ b/bayesflow/networks/inference_network.py
@@ -46,7 +46,9 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         _, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
         return log_density

-    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
+    def compute_metrics(
+        self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
+    ) -> dict[str, Tensor]:
         if not self.built:
             xz_shape = keras.ops.shape(x)
             conditions_shape = None if conditions is None else keras.ops.shape(conditions)
@@ -62,3 +64,10 @@ def compute_metrics(
             metrics[metric.name] = metric(samples, x)

         return metrics
+
+    def aggregate(self, losses: Tensor, weights: Tensor = None):
+        if weights is not None:
+            weighted = losses * weights
+        else:
+            weighted = losses
+        return keras.ops.mean(weighted)
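
The `aggregate` helper added to `InferenceNetwork` above keeps the previous unweighted behavior (a plain `keras.ops.mean`) when no weights are passed, and otherwise takes the mean of the elementwise-weighted losses. Note that the weights are not normalized: it computes mean(w * l), not sum(w * l) / sum(w), so rescaling all weights rescales the loss and its gradients. A minimal numeric sketch of these semantics (the tensor values are illustrative, only `keras.ops` is assumed):

    import keras

    losses = keras.ops.convert_to_tensor([1.0, 2.0, 3.0])
    weights = keras.ops.convert_to_tensor([1.0, 0.0, 1.0])

    # unweighted: mean([1, 2, 3]) = 2.0
    unweighted = keras.ops.mean(losses)

    # weighted: mean([1*1, 2*0, 3*1]) = mean([1, 0, 3]) = 4/3
    # i.e. mean(w * l), not the normalized weighted mean sum(w * l) / sum(w)
    weighted = keras.ops.mean(losses * weights)
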
From 3f4ec657660bff176eb6f7591b92a38aa1cd39aa Mon Sep 17 00:00:00 2001
From: han-ol
Date: Tue, 18 Mar 2025 12:04:00 +0100
Subject: [PATCH 2/4] Allow batch size in NormalSimulator

---
 tests/utils/normal_simulator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/utils/normal_simulator.py b/tests/utils/normal_simulator.py
index dc04c987f..6c44f8b86 100644
--- a/tests/utils/normal_simulator.py
+++ b/tests/utils/normal_simulator.py
@@ -2,11 +2,13 @@

 from bayesflow.simulators import Simulator
 from bayesflow.types import Shape, Tensor
+from bayesflow.utils.decorators import allow_batch_size


 class NormalSimulator(Simulator):
     """TODO: Docstring"""

+    @allow_batch_size
     def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Tensor]:
         mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
         std = np.random.lognormal(0.0, 0.1, size=batch_shape + (2,))
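
As the subject line suggests, the decorator lets `sample` accept a plain integer batch size in place of a shape tuple. A minimal usage sketch — the exact conversion (an int N normalized to the tuple `(N,)`) is an assumption inferred from the decorator's name, not verified against `bayesflow.utils.decorators`:

    from tests.utils.normal_simulator import NormalSimulator

    simulator = NormalSimulator()

    # with @allow_batch_size, these two calls are presumed equivalent
    samples_a = simulator.sample((64,))  # explicit batch shape
    samples_b = simulator.sample(64)     # plain batch size

    assert samples_a["mean"].shape == samples_b["mean"].shape
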
From 3d6a8376e584dbbf4ef0573ccb78e4710893802c Mon Sep 17 00:00:00 2001
From: han-ol
Date: Tue, 18 Mar 2025 12:11:28 +0100
Subject: [PATCH 3/4] Rename to sample_weight, add tests with and without
 sample weights

---
 .../approximators/continuous_approximator.py  | 16 +++---
 .../free_form_flow/free_form_flow.py          |  6 +--
 .../networks/coupling_flow/coupling_flow.py   |  6 +--
 .../networks/flow_matching/flow_matching.py   |  6 +--
 bayesflow/networks/inference_network.py       |  2 +-
 tests/test_approximators/conftest.py          | 53 ++++++++++++++++++-
 6 files changed, 70 insertions(+), 19 deletions(-)

diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index 1b022bbb5..73ad4e621 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -53,7 +53,7 @@ def build_adapter(
         inference_variables: Sequence[str],
         inference_conditions: Sequence[str] = None,
         summary_variables: Sequence[str] = None,
-        sample_weights: Sequence[str] = None,
+        sample_weight: str = None,
     ) -> Adapter:
         """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.

@@ -65,6 +65,8 @@ def build_adapter(
             Names of the inference conditions in the data
         summary_variables : Sequence of str, optional
             Names of the summary variables in the data
+        sample_weight : str, optional
+            Name of the sample weights in the data
         """
         adapter = Adapter.create_default(inference_variables)

@@ -74,12 +76,12 @@ def build_adapter(
         if summary_variables is not None:
             adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")

-        if sample_weights is not None:  # we could provide automatic multiplication of different sample weights
-            adapter = adapter.concatenate(sample_weights, into="sample_weights")
+        if sample_weight is not None:
+            adapter = adapter.rename(sample_weight, "sample_weight")

         adapter = adapter.keep(
-            ["inference_variables", "inference_conditions", "summary_variables", "sample_weights"]
-        ).standardize(exclude="sample_weights")
+            ["inference_variables", "inference_conditions", "summary_variables", "sample_weight"]
+        ).standardize(exclude="sample_weight")

         return adapter

@@ -106,7 +108,7 @@ def compute_metrics(
         inference_variables: Tensor,
         inference_conditions: Tensor = None,
         summary_variables: Tensor = None,
-        sample_weights: Tensor = None,
+        sample_weight: Tensor = None,
         stage: str = "training",
     ) -> dict[str, Tensor]:
         if self.summary_network is None:
@@ -128,7 +130,7 @@ def compute_metrics(
             inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

         inference_metrics = self.inference_network.compute_metrics(
-            inference_variables, conditions=inference_conditions, sample_weights=sample_weights, stage=stage
+            inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
         )

         loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py
index 8d1854f9d..a6be01fa2 100644
--- a/bayesflow/experimental/free_form_flow/free_form_flow.py
+++ b/bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -215,9 +215,9 @@ def _sample_v(self, x):
         return v

     def compute_metrics(
-        self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
+        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
     ) -> dict[str, Tensor]:
-        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)
+        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)

         # sample random vector
         v = self._sample_v(x)
@@ -240,6 +240,6 @@ def decode(z):
         reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)

         losses = maximum_likelihood_loss + self.beta * reconstruction_loss
-        loss = self.aggregate(losses, sample_weights)
+        loss = self.aggregate(losses, sample_weight)

         return base_metrics | {"loss": loss}
diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py
index 46c1afe97..54f428521 100644
--- a/bayesflow/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/networks/coupling_flow/coupling_flow.py
@@ -156,11 +156,11 @@ def _inverse(
         return x

     def compute_metrics(
-        self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
+        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
     ) -> dict[str, Tensor]:
-        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)
+        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
         z, log_density = self(x, conditions=conditions, inverse=False, density=True)
-        loss = self.aggregate(-log_density, sample_weights)
+        loss = self.aggregate(-log_density, sample_weight)

         return base_metrics | {"loss": loss}
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py
index 041fa23ce..912f9c64f 100644
--- a/bayesflow/networks/flow_matching/flow_matching.py
+++ b/bayesflow/networks/flow_matching/flow_matching.py
@@ -234,7 +234,7 @@ def compute_metrics(
         self,
         x: Tensor | Sequence[Tensor, ...],
         conditions: Tensor = None,
-        sample_weights: Tensor = None,
+        sample_weight: Tensor = None,
         stage: str = "training",
     ) -> dict[str, Tensor]:
         if isinstance(x, Sequence):
@@ -260,11 +260,11 @@ def compute_metrics(
         x = t * x1 + (1 - t) * x0
         target_velocity = x1 - x0

-        base_metrics = super().compute_metrics(x1, conditions, sample_weights, stage)
+        base_metrics = super().compute_metrics(x1, conditions, sample_weight, stage)
         predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

         loss = self.loss_fn(target_velocity, predicted_velocity)
-        loss = self.aggregate(loss, sample_weights)
+        loss = self.aggregate(loss, sample_weight)

         return base_metrics | {"loss": loss}
diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py
index 568c792fb..9a5838c53 100644
--- a/bayesflow/networks/inference_network.py
+++ b/bayesflow/networks/inference_network.py
@@ -49,7 +49,7 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         return log_density

     def compute_metrics(
-        self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
+        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
     ) -> dict[str, Tensor]:
         if not self.built:
             xz_shape = keras.ops.shape(x)
diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py
index 455d47fbe..b3c95c368 100644
--- a/tests/test_approximators/conftest.py
+++ b/tests/test_approximators/conftest.py
@@ -30,7 +30,7 @@ def approximator(adapter, inference_network, summary_network):


 @pytest.fixture()
-def adapter():
+def adapter_without_sample_weight():
     from bayesflow import ContinuousApproximator

     return ContinuousApproximator.build_adapter(
@@ -40,14 +40,63 @@ def adapter():
     )


 @pytest.fixture()
-def simulator():
+def adapter_with_sample_weight():
+    from bayesflow import ContinuousApproximator
+
+    return ContinuousApproximator.build_adapter(
+        inference_variables=["mean", "std"],
+        inference_conditions=["x"],
+        sample_weight="weight",
+    )
+
+
+@pytest.fixture(params=["adapter_without_sample_weight", "adapter_with_sample_weight"])
+def adapter(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture()
+def normal_simulator():
     from tests.utils.normal_simulator import NormalSimulator

     return NormalSimulator()


+@pytest.fixture()
+def normal_simulator_with_sample_weight():
+    from tests.utils.normal_simulator import NormalSimulator
+    from bayesflow import make_simulator
+
+    def weight(mean):
+        return dict(weight=1.0)
+
+    return make_simulator([NormalSimulator(), weight])
+
+
+@pytest.fixture(params=["normal_simulator", "normal_simulator_with_sample_weight"])
+def simulator(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture()
 def train_dataset(batch_size, adapter, simulator):
+    # scan the adapter representation for an occurrence of a rename pattern for 'sample_weight'
+    adapter_with_sample_weight = "-> 'sample_weight'" in str(adapter)
+    # check whether the simulator returns a 'weight' key
+    simulator_with_sample_weight = "weight" in simulator.sample(1).keys()
+
+    if adapter_with_sample_weight and not simulator_with_sample_weight:
+        # the adapter expects a 'weight' key that is missing, so it should raise a KeyError
+        with pytest.raises(KeyError):
+            adapter(simulator.sample(1))
+        # Don't use this fixture combination for further tests.
+        pytest.skip()
+    elif not adapter_with_sample_weight and simulator_with_sample_weight:
+        # When a 'weight' key is present, but the adapter does not configure it
+        # to be used as sample weight, no error is raised currently.
+        # Don't use this fixture combination for further tests.
+        pytest.skip()
+
     from bayesflow import OfflineDataset

     num_batches = 4

From c4f27be54b23449b27c050d9e63becf0a9ae3137 Mon Sep 17 00:00:00 2001
From: han-ol
Date: Tue, 18 Mar 2025 13:29:19 +0100
Subject: [PATCH 4/4] Skip flaky fit progress and sample test for multivariate
 normal score estimation

---
 tests/test_approximators/test_fit.py           |  4 ++++
 .../test_point_approximators/test_sample.py    |  5 ++++-
 tests/utils/check_combinations.py              | 10 ++++++++++
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/tests/test_approximators/test_fit.py b/tests/test_approximators/test_fit.py
index 27d4716c4..b561efb77 100644
--- a/tests/test_approximators/test_fit.py
+++ b/tests/test_approximators/test_fit.py
@@ -3,6 +3,7 @@
 import pytest
 import io
 from contextlib import redirect_stdout
+from tests.utils import check_approximator_multivariate_normal_score


 @pytest.mark.skip(reason="not implemented")
@@ -19,6 +20,9 @@ def test_fit(amortizer, dataset):


 def test_loss_progress(approximator, train_dataset, validation_dataset):
+    # as long as MultivariateNormalScore is unstable, skip fit progress test
+    check_approximator_multivariate_normal_score(approximator)
+
     approximator.compile(optimizer="AdamW")
     num_epochs = 3

diff --git a/tests/test_approximators/test_point_approximators/test_sample.py b/tests/test_approximators/test_point_approximators/test_sample.py
index 5755037a5..6f56828d9 100644
--- a/tests/test_approximators/test_point_approximators/test_sample.py
+++ b/tests/test_approximators/test_point_approximators/test_sample.py
@@ -1,12 +1,15 @@
 import keras
 import numpy as np

 from bayesflow.scores import ParametricDistributionScore
-from tests.utils import check_combination_simulator_adapter
+from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score


 def test_approximator_sample(point_approximator, simulator, batch_size, num_samples, adapter):
     check_combination_simulator_adapter(simulator, adapter)

+    # as long as MultivariateNormalScore is unstable, skip test
+    check_approximator_multivariate_normal_score(point_approximator)
+
     data = simulator.sample((batch_size,))
     batch = adapter(data)

diff --git a/tests/utils/check_combinations.py b/tests/utils/check_combinations.py
index a7d85afcd..8d3fa5d46 100644
--- a/tests/utils/check_combinations.py
+++ b/tests/utils/check_combinations.py
@@ -19,3 +19,13 @@ def check_combination_simulator_adapter(simulator, adapter):
         # to be used as sample weight, no error is raised currently.
         # Don't use this fixture combination for further tests.
         pytest.skip()
+
+
+def check_approximator_multivariate_normal_score(approximator):
+    from bayesflow.approximators import PointApproximator
+    from bayesflow.scores import MultivariateNormalScore
+
+    if isinstance(approximator, PointApproximator):
+        for score in approximator.inference_network.scores.values():
+            if isinstance(score, MultivariateNormalScore):
+                pytest.skip()
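
End to end, the plumbing added in this series is driven from the adapter: the simulator emits a weight key, build_adapter(sample_weight=...) renames it to 'sample_weight' (kept but excluded from standardization), and the networks fold it into aggregate(). A usage sketch mirroring the test fixtures above — the weight function and key names are the ones used in conftest.py; everything about approximator construction is elided:

    from bayesflow import ContinuousApproximator, make_simulator
    from tests.utils.normal_simulator import NormalSimulator

    def weight(mean):
        # toy per-sample weight; any function of simulated variables works
        return dict(weight=1.0)

    simulator = make_simulator([NormalSimulator(), weight])

    adapter = ContinuousApproximator.build_adapter(
        inference_variables=["mean", "std"],
        inference_conditions=["x"],
        sample_weight="weight",
    )

    data = simulator.sample((256,))
    batch = adapter(data)  # contains "sample_weight", passed through to compute_metrics
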