Skip to content

Commit d2d9a36

Browse files
committed
Merge remote-tracking branch 'origin/dev' into dev
2 parents 1de8c07 + ffb212f commit d2d9a36

File tree

13 files changed

+128
-18
lines changed

13 files changed

+128
-18
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def build_adapter(
5353
inference_variables: Sequence[str],
5454
inference_conditions: Sequence[str] = None,
5555
summary_variables: Sequence[str] = None,
56+
sample_weight: Sequence[str] = None,
5657
) -> Adapter:
5758
"""Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
5859
@@ -64,6 +65,8 @@ def build_adapter(
6465
Names of the inference conditions in the data
6566
summary_variables : Sequence of str, optional
6667
Names of the summary variables in the data
68+
sample_weight : str, optional
69+
Name of the sample weights
6770
"""
6871
adapter = Adapter()
6972
adapter.to_array()
@@ -77,8 +80,11 @@ def build_adapter(
7780
adapter.as_set(summary_variables)
7881
adapter.concatenate(summary_variables, into="summary_variables")
7982

80-
adapter.keep(["inference_variables", "inference_conditions", "summary_variables"])
81-
adapter.standardize()
83+
if sample_weight is not None:
84+
adapter = adapter.rename(sample_weight, "sample_weight")
85+
86+
adapter.keep(["inference_variables", "inference_conditions", "summary_variables", "sample_weight"])
87+
adapter.standardize(exclude="sample_weight")
8288

8389
return adapter
8490

@@ -105,6 +111,7 @@ def compute_metrics(
105111
inference_variables: Tensor,
106112
inference_conditions: Tensor = None,
107113
summary_variables: Tensor = None,
114+
sample_weight: Tensor = None,
108115
stage: str = "training",
109116
) -> dict[str, Tensor]:
110117
if self.summary_network is None:
@@ -128,7 +135,7 @@ def compute_metrics(
128135
# Force a conversion to Tensor
129136
inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
130137
inference_metrics = self.inference_network.compute_metrics(
131-
inference_variables, conditions=inference_conditions, stage=stage
138+
inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
132139
)
133140

134141
loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ def _sample_v(self, x):
214214
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
215215
return v
216216

217-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
218-
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
217+
def compute_metrics(
218+
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
219+
) -> dict[str, Tensor]:
220+
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
219221
# sample random vector
220222
v = self._sample_v(x)
221223

@@ -236,6 +238,8 @@ def decode(z):
236238
nll = -self.base_distribution.log_prob(z)
237239
maximum_likelihood_loss = nll - surrogate
238240
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
239-
loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)
241+
242+
losses = maximum_likelihood_loss + self.beta * reconstruction_loss
243+
loss = self.aggregate(losses, sample_weight)
240244

241245
return base_metrics | {"loss": loss}

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,14 @@ def _inverse(
155155

156156
return x
157157

158-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
159-
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
158+
def compute_metrics(
159+
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
160+
) -> dict[str, Tensor]:
161+
if sample_weight is not None:
162+
print(sample_weight)
163+
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
160164

161165
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
162-
loss = -keras.ops.mean(log_density)
166+
loss = self.aggregate(-log_density, sample_weight)
163167

164168
return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,11 @@ def deltas(time, xz):
225225
return x
226226

227227
def compute_metrics(
228-
self, x: Tensor | Sequence[Tensor, ...], conditions: Tensor = None, stage: str = "training"
228+
self,
229+
x: Tensor | Sequence[Tensor, ...],
230+
conditions: Tensor = None,
231+
sample_weight: Tensor = None,
232+
stage: str = "training",
229233
) -> dict[str, Tensor]:
230234
if isinstance(x, Sequence):
231235
# already pre-configured
@@ -250,11 +254,11 @@ def compute_metrics(
250254
x = t * x1 + (1 - t) * x0
251255
target_velocity = x1 - x0
252256

253-
base_metrics = super().compute_metrics(x1, conditions, stage)
257+
base_metrics = super().compute_metrics(x1, conditions, sample_weight, stage)
254258

255259
predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")
256260

257261
loss = self.loss_fn(target_velocity, predicted_velocity)
258-
loss = keras.ops.mean(loss)
262+
loss = self.aggregate(loss, sample_weight)
259263

260264
return base_metrics | {"loss": loss}

bayesflow/networks/inference_network.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tens
4848
_, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
4949
return log_density
5050

51-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
51+
def compute_metrics(
52+
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
53+
) -> dict[str, Tensor]:
5254
if not self.built:
5355
xz_shape = keras.ops.shape(x)
5456
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
@@ -64,3 +66,10 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
6466
metrics[metric.name] = metric(samples, x)
6567

6668
return metrics
69+
70+
def aggregate(self, losses: Tensor, weights: Tensor = None):
71+
if weights is not None:
72+
weighted = losses * weights
73+
else:
74+
weighted = losses
75+
return keras.ops.mean(weighted)

bayesflow/networks/point_inference_network.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,15 @@ def call(
144144
}
145145
return output
146146

147-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
147+
def compute_metrics(
148+
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
149+
) -> dict[str, Tensor]:
148150
output = self(x, conditions)
149151

150152
metrics = {}
151153
# calculate negative score as mean over all scores
152154
for score_key, score in self.scores.items():
153-
score_value = score.score(output[score_key], x)
155+
score_value = score.score(output[score_key], x, sample_weight)
154156
metrics[score_key] = score_value
155157
neg_score = keras.ops.mean(list(metrics.values()))
156158

tests/test_approximators/conftest.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from tests.utils import check_combination_simulator_adapter
23

34

45
@pytest.fixture()
@@ -96,7 +97,7 @@ def approximator(request):
9697

9798

9899
@pytest.fixture()
99-
def adapter():
100+
def adapter_without_sample_weight():
100101
from bayesflow import ContinuousApproximator
101102

102103
return ContinuousApproximator.build_adapter(
@@ -106,14 +107,48 @@ def adapter():
106107

107108

108109
@pytest.fixture()
109-
def simulator():
110+
def adapter_with_sample_weight():
111+
from bayesflow import ContinuousApproximator
112+
113+
return ContinuousApproximator.build_adapter(
114+
inference_variables=["mean", "std"],
115+
inference_conditions=["x"],
116+
sample_weight="weight",
117+
)
118+
119+
120+
@pytest.fixture(params=["adapter_without_sample_weight", "adapter_with_sample_weight"])
121+
def adapter(request):
122+
return request.getfixturevalue(request.param)
123+
124+
125+
@pytest.fixture()
126+
def normal_simulator():
110127
from tests.utils.normal_simulator import NormalSimulator
111128

112129
return NormalSimulator()
113130

114131

132+
@pytest.fixture()
133+
def normal_simulator_with_sample_weight():
134+
from tests.utils.normal_simulator import NormalSimulator
135+
from bayesflow import make_simulator
136+
137+
def weight(mean):
138+
return dict(weight=1.0)
139+
140+
return make_simulator([NormalSimulator(), weight])
141+
142+
143+
@pytest.fixture(params=["normal_simulator", "normal_simulator_with_sample_weight"])
144+
def simulator(request):
145+
return request.getfixturevalue(request.param)
146+
147+
115148
@pytest.fixture()
116149
def train_dataset(batch_size, adapter, simulator):
150+
check_combination_simulator_adapter(simulator, adapter)
151+
117152
from bayesflow import OfflineDataset
118153

119154
num_batches = 4

tests/test_approximators/test_estimate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import keras
2+
from tests.utils import check_combination_simulator_adapter
23

34

45
def test_approximator_estimate(approximator, simulator, batch_size, adapter):
5-
approximator = approximator
6+
check_combination_simulator_adapter(simulator, adapter)
67

78
num_batches = 4
89
data = simulator.sample((num_batches * batch_size,))

tests/test_approximators/test_fit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import io
55
from contextlib import redirect_stdout
6+
from tests.utils import check_approximator_multivariate_normal_score
67

78

89
@pytest.mark.skip(reason="not implemented")
@@ -19,6 +20,9 @@ def test_fit(amortizer, dataset):
1920

2021

2122
def test_loss_progress(approximator, train_dataset, validation_dataset):
23+
# as long as MultivariateNormalScore is unstable, skip fit progress test
24+
check_approximator_multivariate_normal_score(approximator)
25+
2226
approximator.compile(optimizer="AdamW")
2327
num_epochs = 3
2428

tests/test_approximators/test_point_approximators/test_sample.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import keras
22
import numpy as np
33
from bayesflow.scores import ParametricDistributionScore
4+
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
45

56

67
def test_approximator_sample(point_approximator, simulator, batch_size, num_samples, adapter):
8+
check_combination_simulator_adapter(simulator, adapter)
9+
10+
# as long as MultivariateNormalScore is unstable, skip test
11+
check_approximator_multivariate_normal_score(point_approximator)
12+
713
data = simulator.sample((batch_size,))
814

915
batch = adapter(data)

0 commit comments

Comments
 (0)