Skip to content

Commit 7b2827a

Browse files
committed
Add sample weight support to change loss aggregation
1 parent a9f17de commit 7b2827a

File tree

5 files changed

+38
-12
lines changed

5 files changed

+38
-12
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def build_adapter(
4141
inference_variables: Sequence[str],
4242
inference_conditions: Sequence[str] = None,
4343
summary_variables: Sequence[str] = None,
44+
sample_weights: Sequence[str] = None,
4445
) -> Adapter:
4546
adapter = Adapter.create_default(inference_variables)
4647

@@ -50,7 +51,12 @@ def build_adapter(
5051
if summary_variables is not None:
5152
adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")
5253

53-
adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
54+
if sample_weights is not None: # we could provide automatic multiplication of different sample weights
55+
adapter = adapter.concatenate(sample_weights, into="sample_weights")
56+
57+
adapter = adapter.keep(
58+
["inference_variables", "inference_conditions", "summary_variables", "sample_weights"]
59+
).standardize(exclude="sample_weights")
5460

5561
return adapter
5662

@@ -77,6 +83,7 @@ def compute_metrics(
7783
inference_variables: Tensor,
7884
inference_conditions: Tensor = None,
7985
summary_variables: Tensor = None,
86+
sample_weights: Tensor = None,
8087
stage: str = "training",
8188
) -> dict[str, Tensor]:
8289
if self.summary_network is None:
@@ -98,7 +105,7 @@ def compute_metrics(
98105
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
99106

100107
inference_metrics = self.inference_network.compute_metrics(
101-
inference_variables, conditions=inference_conditions, stage=stage
108+
inference_variables, conditions=inference_conditions, sample_weights=sample_weights, stage=stage
102109
)
103110

104111
loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,12 @@ def _inverse(
117117

118118
return x
119119

120-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
121-
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
120+
def compute_metrics(
121+
self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
122+
) -> dict[str, Tensor]:
123+
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)
122124

123125
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
124-
loss = -keras.ops.mean(log_density)
126+
loss = self.aggregate(-log_density, sample_weights)
125127

126128
return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,11 @@ def deltas(t, xz):
183183
return x
184184

185185
def compute_metrics(
186-
self, x: Tensor | Sequence[Tensor, ...], conditions: Tensor = None, stage: str = "training"
186+
self,
187+
x: Tensor | Sequence[Tensor, ...],
188+
conditions: Tensor = None,
189+
sample_weights: Tensor = None,
190+
stage: str = "training",
187191
) -> dict[str, Tensor]:
188192
if isinstance(x, Sequence):
189193
# already pre-configured
@@ -208,11 +212,11 @@ def compute_metrics(
208212
x = t * x1 + (1 - t) * x0
209213
target_velocity = x1 - x0
210214

211-
base_metrics = super().compute_metrics(x1, conditions, stage)
215+
base_metrics = super().compute_metrics(x1, conditions, sample_weights, stage)
212216

213217
predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")
214218

215219
loss = self.loss_fn(target_velocity, predicted_velocity)
216-
loss = keras.ops.mean(loss)
220+
loss = self.aggregate(loss, sample_weights)
217221

218222
return base_metrics | {"loss": loss}

bayesflow/networks/free_form_flow/free_form_flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,10 @@ def _sample_v(self, x):
182182
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
183183
return v
184184

185-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
186-
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
185+
def compute_metrics(
186+
self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
187+
) -> dict[str, Tensor]:
188+
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)
187189
# sample random vector
188190
v = self._sample_v(x)
189191

@@ -204,6 +206,8 @@ def decode(z):
204206
nll = -self.base_distribution.log_prob(z)
205207
maximum_likelihood_loss = nll - surrogate
206208
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
207-
loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)
209+
210+
losses = maximum_likelihood_loss + self.beta * reconstruction_loss
211+
loss = self.aggregate(losses, sample_weights)
208212

209213
return base_metrics | {"loss": loss}

bayesflow/networks/inference_network.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tens
4646
_, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
4747
return log_density
4848

49-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
49+
def compute_metrics(
50+
self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
51+
) -> dict[str, Tensor]:
5052
if not self.built:
5153
xz_shape = keras.ops.shape(x)
5254
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
@@ -62,3 +64,10 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
6264
metrics[metric.name] = metric(samples, x)
6365

6466
return metrics
67+
68+
def aggregate(self, losses: Tensor, weights: Tensor = None):
69+
if weights is not None:
70+
weighted = losses * weights
71+
else:
72+
weighted = losses
73+
return keras.ops.mean(weighted)

0 commit comments

Comments
 (0)