 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
+from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum
 
 
 from ..inference_network import InferenceNetwork
@@ -285,7 +285,9 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
         out = skip * x + out * f
         return out
 
-    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
+    def compute_metrics(
+        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
+    ) -> dict[str, Tensor]:
         base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
 
         # The discretization schedule requires the number of passed training steps.
@@ -328,6 +330,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
         lam = 1 / (t2 - t1)
 
         # Pseudo-huber loss, see [2], Section 3.3
-        loss = ops.mean(lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber))
+        loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
+        loss = weighted_sum(loss, sample_weight)
 
         return base_metrics | {"loss": loss}
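
For context: after this change, `loss` holds the per-element pseudo-Huber terms `lam * (sqrt((teacher_out - student_out)^2 + c_huber^2) - c_huber)`, and `weighted_sum` reduces them to a scalar, applying per-sample weights when provided. The implementation of `bayesflow.utils.weighted_sum` is not shown in this diff, so the following is a minimal sketch of such a helper, assuming `keras.ops`-style backend ops and assuming that the unweighted case should reproduce the old `ops.mean(...)` reduction:

```python
# Hypothetical sketch for illustration only: the actual
# bayesflow.utils.weighted_sum is not part of this diff, so its exact
# reduction semantics (mean vs. sum, normalization) are assumptions.
from keras import ops


def weighted_sum(loss, sample_weight=None):
    """Reduce an elementwise loss tensor to a scalar.

    With sample_weight=None this falls back to a plain mean, matching
    the previous ops.mean(...) behavior of compute_metrics.
    """
    if sample_weight is None:
        return ops.mean(loss)
    if ops.ndim(loss) > 1:
        # Collapse non-batch axes so weights of shape (batch,) broadcast.
        loss = ops.mean(loss, axis=tuple(range(1, ops.ndim(loss))))
    # Normalize by the total weight (an assumed convention).
    return ops.sum(loss * sample_weight) / ops.sum(sample_weight)
```

Because `sample_weight` defaults to `None`, callers that pass no weights get the previous unweighted loss, so existing training loops are unaffected.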