Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def build_adapter(
inference_variables: Sequence[str],
inference_conditions: Sequence[str] = None,
summary_variables: Sequence[str] = None,
sample_weight: Sequence[str] = None,
) -> Adapter:
"""Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.

Expand All @@ -64,6 +65,8 @@ def build_adapter(
Names of the inference conditions in the data
summary_variables : Sequence of str, optional
Names of the summary variables in the data
sample_weight : str, optional
Name of the sample weights
"""
adapter = Adapter()
adapter.to_array()
Expand All @@ -77,8 +80,11 @@ def build_adapter(
adapter.as_set(summary_variables)
adapter.concatenate(summary_variables, into="summary_variables")

adapter.keep(["inference_variables", "inference_conditions", "summary_variables"])
adapter.standardize()
if sample_weight is not None:
adapter = adapter.rename(sample_weight, "sample_weight")

adapter.keep(["inference_variables", "inference_conditions", "summary_variables", "sample_weight"])
adapter.standardize(exclude="sample_weight")

return adapter

Expand All @@ -105,6 +111,7 @@ def compute_metrics(
inference_variables: Tensor,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
sample_weight: Tensor = None,
stage: str = "training",
) -> dict[str, Tensor]:
if self.summary_network is None:
Expand All @@ -128,7 +135,7 @@ def compute_metrics(
# Force a conversion to Tensor
inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
inference_metrics = self.inference_network.compute_metrics(
inference_variables, conditions=inference_conditions, stage=stage
inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
)

loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
Expand Down
10 changes: 7 additions & 3 deletions bayesflow/experimental/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ def _sample_v(self, x):
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
return v

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
# sample random vector
v = self._sample_v(x)

Expand All @@ -236,6 +238,8 @@ def decode(z):
nll = -self.base_distribution.log_prob(z)
maximum_likelihood_loss = nll - surrogate
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

losses = maximum_likelihood_loss + self.beta * reconstruction_loss
loss = self.aggregate(losses, sample_weight)

return base_metrics | {"loss": loss}
10 changes: 7 additions & 3 deletions bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,14 @@ def _inverse(

return x

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
if sample_weight is not None:
print(sample_weight)
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)

z, log_density = self(x, conditions=conditions, inverse=False, density=True)
loss = -keras.ops.mean(log_density)
loss = self.aggregate(-log_density, sample_weight)

return base_metrics | {"loss": loss}
10 changes: 7 additions & 3 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ def deltas(time, xz):
return x

def compute_metrics(
self, x: Tensor | Sequence[Tensor, ...], conditions: Tensor = None, stage: str = "training"
self,
x: Tensor | Sequence[Tensor, ...],
conditions: Tensor = None,
sample_weight: Tensor = None,
stage: str = "training",
) -> dict[str, Tensor]:
if isinstance(x, Sequence):
# already pre-configured
Expand All @@ -248,11 +252,11 @@ def compute_metrics(
x = t * x1 + (1 - t) * x0
target_velocity = x1 - x0

base_metrics = super().compute_metrics(x1, conditions, stage)
base_metrics = super().compute_metrics(x1, conditions, sample_weight, stage)

predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")

loss = self.loss_fn(target_velocity, predicted_velocity)
loss = keras.ops.mean(loss)
loss = self.aggregate(loss, sample_weight)

return base_metrics | {"loss": loss}
11 changes: 10 additions & 1 deletion bayesflow/networks/inference_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tens
_, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
return log_density

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
if not self.built:
xz_shape = keras.ops.shape(x)
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
Expand All @@ -64,3 +66,10 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
metrics[metric.name] = metric(samples, x)

return metrics

def aggregate(self, losses: Tensor, weights: Tensor = None):
if weights is not None:
weighted = losses * weights
else:
weighted = losses
return keras.ops.mean(weighted)
6 changes: 4 additions & 2 deletions bayesflow/networks/point_inference_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ def call(
}
return output

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
output = self(x, conditions)

metrics = {}
# calculate negative score as mean over all scores
for score_key, score in self.scores.items():
score_value = score.score(output[score_key], x)
score_value = score.score(output[score_key], x, sample_weight)
metrics[score_key] = score_value
neg_score = keras.ops.mean(list(metrics.values()))

Expand Down
39 changes: 37 additions & 2 deletions tests/test_approximators/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from tests.utils import check_combination_simulator_adapter


@pytest.fixture()
Expand Down Expand Up @@ -96,7 +97,7 @@ def approximator(request):


@pytest.fixture()
def adapter():
def adapter_without_sample_weight():
from bayesflow import ContinuousApproximator

return ContinuousApproximator.build_adapter(
Expand All @@ -106,14 +107,48 @@ def adapter():


@pytest.fixture()
def simulator():
def adapter_with_sample_weight():
from bayesflow import ContinuousApproximator

return ContinuousApproximator.build_adapter(
inference_variables=["mean", "std"],
inference_conditions=["x"],
sample_weight="weight",
)


@pytest.fixture(params=["adapter_without_sample_weight", "adapter_with_sample_weight"])
def adapter(request):
return request.getfixturevalue(request.param)


@pytest.fixture()
def normal_simulator():
from tests.utils.normal_simulator import NormalSimulator

return NormalSimulator()


@pytest.fixture()
def normal_simulator_with_sample_weight():
from tests.utils.normal_simulator import NormalSimulator
from bayesflow import make_simulator

def weight(mean):
return dict(weight=1.0)

return make_simulator([NormalSimulator(), weight])


@pytest.fixture(params=["normal_simulator", "normal_simulator_with_sample_weight"])
def simulator(request):
return request.getfixturevalue(request.param)


@pytest.fixture()
def train_dataset(batch_size, adapter, simulator):
check_combination_simulator_adapter(simulator, adapter)

from bayesflow import OfflineDataset

num_batches = 4
Expand Down
3 changes: 2 additions & 1 deletion tests/test_approximators/test_estimate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import keras
from tests.utils import check_combination_simulator_adapter


def test_approximator_estimate(approximator, simulator, batch_size, adapter):
approximator = approximator
check_combination_simulator_adapter(simulator, adapter)

num_batches = 4
data = simulator.sample((num_batches * batch_size,))
Expand Down
4 changes: 4 additions & 0 deletions tests/test_approximators/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import io
from contextlib import redirect_stdout
from tests.utils import check_approximator_multivariate_normal_score


@pytest.mark.skip(reason="not implemented")
Expand All @@ -19,6 +20,9 @@ def test_fit(amortizer, dataset):


def test_loss_progress(approximator, train_dataset, validation_dataset):
# as long as MultivariateNormalScore is unstable, skip fit progress test
check_approximator_multivariate_normal_score(approximator)

approximator.compile(optimizer="AdamW")
num_epochs = 3

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import keras
import numpy as np
from bayesflow.scores import ParametricDistributionScore
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score


def test_approximator_sample(point_approximator, simulator, batch_size, num_samples, adapter):
check_combination_simulator_adapter(simulator, adapter)

# as long as MultivariateNormalScore is unstable, skip test
check_approximator_multivariate_normal_score(point_approximator)

data = simulator.sample((batch_size,))

batch = adapter(data)
Expand Down
1 change: 1 addition & 0 deletions tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .callbacks import *
from .ops import *
from .ecdf import *
from .check_combinations import *
31 changes: 31 additions & 0 deletions tests/utils/check_combinations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest


def check_combination_simulator_adapter(simulator, adapter):
"""Make sure simulator and adapter fixtures fit together and appropriate errors are raised if not."""
# check whether the simulator returns a 'weight' key
simulator_with_sample_weight = "weight" in simulator.sample(1).keys()
# scan adapter representation for occurance of a rename pattern for 'sample_weight'
adapter_with_sample_weight = "-> 'sample_weight'" in str(adapter)

if not simulator_with_sample_weight and adapter_with_sample_weight:
# adapter should expect a 'weight' key and raise a KeyError.
with pytest.raises(KeyError):
adapter(simulator.sample(1))
# Don't use this fixture combination for further tests.
pytest.skip()
elif simulator_with_sample_weight and not adapter_with_sample_weight:
# When a weight key is present, but the adapter does not configure it
# to be used as sample weight, no error is raised currently.
# Don't use this fixture combination for further tests.
pytest.skip()


def check_approximator_multivariate_normal_score(approximator):
from bayesflow.approximators import PointApproximator
from bayesflow.scores import MultivariateNormalScore

if isinstance(approximator, PointApproximator):
for score in approximator.inference_network.scores.values():
if isinstance(score, MultivariateNormalScore):
pytest.skip()
2 changes: 2 additions & 0 deletions tests/utils/normal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from bayesflow.simulators import Simulator
from bayesflow.types import Shape, Tensor
from bayesflow.utils.decorators import allow_batch_size


class NormalSimulator(Simulator):
"""TODO: Docstring"""

@allow_batch_size
def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Tensor]:
mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
std = np.random.lognormal(0.0, 0.1, size=batch_shape + (2,))
Expand Down