Skip to content

Commit c5a5d6b

Browse files
committed
correctly track train / validation losses
1 parent 4781e2e commit c5a5d6b

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
1212

1313
def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
1414
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
15-
return self.compute_metrics(**kwargs)
15+
metrics = self.compute_metrics(**kwargs)
16+
loss = metrics["loss"]
17+
self._loss_tracker.update_state(loss)
18+
return metrics
1619

1720
def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
1821
with tf.GradientTape() as tape:

bayesflow/approximators/backend_approximators/torch_approximator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
1212

1313
def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
1414
kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
15-
return self.compute_metrics(**kwargs)
15+
metrics = self.compute_metrics(**kwargs)
16+
loss = metrics["loss"]
17+
self._loss_tracker.update_state(loss)
18+
return metrics
1619

1720
def train_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
1821
with torch.enable_grad():

bayesflow/approximators/continuous_approximator.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,9 @@ def compute_metrics(
150150
inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
151151
)
152152

153-
loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
153+
loss = inference_metrics["loss"] + summary_metrics.get("loss", keras.ops.zeros(()))
154154

155-
inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
156-
summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}
157-
158-
metrics = {"loss": loss} | inference_metrics | summary_metrics
159-
160-
return metrics
155+
return {"loss": loss}
161156

162157
def fit(self, *args, **kwargs):
163158
"""

0 commit comments

Comments
 (0)