Merged

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

@@ -12,7 +12,10 @@
 
     def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
         kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
-        return self.compute_metrics(**kwargs)
+        metrics = self.compute_metrics(**kwargs)
+        loss = metrics["loss"]
+        self._loss_tracker.update_state(loss)
+        return metrics
 
     def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
         with tf.GradientTape() as tape:

Codecov / codecov/patch warning: added lines 15-18 (the new test_step body) were not covered by tests.
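
For context, filter_kwargs narrows the batch dictionary down to the keyword arguments that compute_metrics actually accepts, so extra keys in data do not raise a TypeError. A minimal sketch of that behavior, assuming an inspect-based implementation; the helper body and example names below are illustrative, not bayesflow's actual code:

# Hypothetical sketch of what filter_kwargs does: keep only the entries of
# `kwargs` whose keys name parameters of `fn`.
import inspect


def filter_kwargs(kwargs: dict, fn) -> dict:
    accepted = inspect.signature(fn).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}


def compute_metrics(inference_variables, inference_conditions=None, stage="training"):
    ...


data = {"inference_variables": ..., "unused_key": 0}
kwargs = filter_kwargs(data | {"stage": "validation"}, compute_metrics)
# kwargs == {"inference_variables": ..., "stage": "validation"}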
bayesflow/approximators/backend_approximators/torch_approximator.py

@@ -12,7 +12,10 @@
 
     def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
         kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
-        return self.compute_metrics(**kwargs)
+        metrics = self.compute_metrics(**kwargs)
+        loss = metrics["loss"]
+        self._loss_tracker.update_state(loss)
+        return metrics
 
     def train_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
         with torch.enable_grad():

Codecov / codecov/patch warning: added lines 15-18 (the new test_step body) were not covered by tests.
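
The three added lines in both backends feed the scalar batch loss into a running metric, presumably so that the validation loss Keras reports is averaged over batches rather than taken from the last batch only. A minimal sketch of that pattern, assuming self._loss_tracker is a keras.metrics.Mean as on keras.Model; the batch losses are made up:

# Sketch of the loss-tracking pattern used by the new test_step lines.
import keras

loss_tracker = keras.metrics.Mean(name="loss")

for batch_loss in [0.9, 0.7, 0.5]:  # one scalar loss per validation batch
    loss_tracker.update_state(batch_loss)

print(float(loss_tracker.result()))  # running mean over the epoch, here 0.7

loss_tracker.reset_state()  # trackers are reset at the start of each epoch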
9 changes: 2 additions & 7 deletions  bayesflow/approximators/continuous_approximator.py

@@ -150,14 +150,9 @@ def compute_metrics(
             inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
         )
 
-        loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
+        loss = inference_metrics["loss"] + summary_metrics.get("loss", keras.ops.zeros(()))
 
-        inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
-        summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}
-
-        metrics = {"loss": loss} | inference_metrics | summary_metrics
-
-        return metrics
+        return {"loss": loss}
 
     def fit(self, *args, **kwargs):
         """
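
In continuous_approximator.py, compute_metrics now adds the required inference loss to an optional summary loss and returns only the combined value, instead of also returning the per-module metric dictionaries. A small sketch of the new return contract, assuming the inference network always reports a loss while a summary network may not; the tensor values are illustrative:

# Sketch of the new aggregation: direct indexing raises a KeyError if the
# inference loss is missing, while the summary loss defaults to zero.
import keras

inference_metrics = {"loss": keras.ops.convert_to_tensor(0.8)}
summary_metrics = {}  # e.g. no summary network, or it reports no loss

loss = inference_metrics["loss"] + summary_metrics.get("loss", keras.ops.zeros(()))
metrics = {"loss": loss}  # this is all compute_metrics returns now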