Skip to content

Commit 22c75d1

Browse files
committed
Remove aggregate and fix sample weight
1 parent 0564092 commit 22c75d1

File tree

6 files changed

+40
-20
lines changed

6 files changed

+40
-20
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
from bayesflow.types import Tensor
10-
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
10+
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum
1111

1212

1313
from ..inference_network import InferenceNetwork
@@ -285,7 +285,9 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
285285
out = skip * x + out * f
286286
return out
287287

288-
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
288+
def compute_metrics(
289+
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
290+
) -> dict[str, Tensor]:
289291
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
290292

291293
# The discretization schedule requires the number of passed training steps.
@@ -328,6 +330,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
328330
lam = 1 / (t2 - t1)
329331

330332
# Pseudo-huber loss, see [2], Section 3.3
331-
loss = ops.mean(lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber))
333+
loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
334+
loss = weighted_sum(loss, sample_weight)
332335

333336
return base_metrics | {"loss": loss}

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.types import Tensor
5-
from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
5+
from bayesflow.utils import (
6+
find_permutation,
7+
keras_kwargs,
8+
serialize_value_or_type,
9+
deserialize_value_or_type,
10+
weighted_sum,
11+
)
612

713
from .actnorm import ActNorm
814
from .couplings import DualCoupling
@@ -158,11 +164,9 @@ def _inverse(
158164
def compute_metrics(
159165
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
160166
) -> dict[str, Tensor]:
161-
if sample_weight is not None:
162-
print(sample_weight)
163-
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
167+
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
164168

165169
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
166-
loss = self.aggregate(-log_density, sample_weight)
170+
loss = weighted_sum(-log_density, sample_weight)
167171

168172
return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
optimal_transport,
1414
serialize_value_or_type,
1515
deserialize_value_or_type,
16+
weighted_sum,
1617
)
1718
from ..inference_network import InferenceNetwork
1819

@@ -254,11 +255,11 @@ def compute_metrics(
254255
x = t * x1 + (1 - t) * x0
255256
target_velocity = x1 - x0
256257

257-
base_metrics = super().compute_metrics(x1, conditions, sample_weight, stage)
258+
base_metrics = super().compute_metrics(x1, conditions=conditions, stage=stage)
258259

259260
predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")
260261

261262
loss = self.loss_fn(target_velocity, predicted_velocity)
262-
loss = self.aggregate(loss, sample_weight)
263+
loss = weighted_sum(loss, sample_weight)
263264

264265
return base_metrics | {"loss": loss}

bayesflow/networks/inference_network.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tens
4848
_, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
4949
return log_density
5050

51-
def compute_metrics(
52-
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
53-
) -> dict[str, Tensor]:
51+
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
5452
if not self.built:
5553
xz_shape = keras.ops.shape(x)
5654
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
@@ -66,10 +64,3 @@ def compute_metrics(
6664
metrics[metric.name] = metric(samples, x)
6765

6866
return metrics
69-
70-
def aggregate(self, losses: Tensor, weights: Tensor = None):
71-
if weights is not None:
72-
weighted = losses * weights
73-
else:
74-
weighted = losses
75-
return keras.ops.mean(weighted)

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
tree_concatenate,
7272
tree_stack,
7373
fill_triangular_matrix,
74+
weighted_sum,
7475
)
7576
from .validators import check_lengths_same
7677
from .workflow_utils import find_inference_network, find_summary_network

bayesflow/utils/tensor_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ def pad(x: Tensor, value: float | Tensor, n: int, axis: int, side: str = "both")
140140
raise TypeError(f"Invalid side type {type(side)!r}. Must be str.")
141141

142142

143+
def weighted_sum(elements: Tensor, weights: Tensor = None) -> Tensor:
144+
"""
145+
Compute the (optionally) weighted mean of the input tensor.
146+
147+
Parameters
148+
----------
149+
elements : Tensor
150+
A tensor containing the elements to average.
151+
weights : Tensor, optional
152+
A tensor of the same shape as `elements` representing weights.
153+
If None, the mean is computed without weights.
154+
155+
Returns
156+
-------
157+
Tensor
158+
A scalar tensor representing the (weighted) mean.
159+
"""
160+
return keras.ops.mean(elements * weights if weights is not None else elements)
161+
162+
143163
def searchsorted(sorted_sequence: Tensor, values: Tensor, side: str = "left") -> Tensor:
144164
"""
145165
Find indices where elements should be inserted to maintain order.

0 commit comments

Comments (0)