diff --git a/bayesflow/experimental/cif/cif.py b/bayesflow/experimental/cif/cif.py index e6b4c9a6c..2e1c893ff 100644 --- a/bayesflow/experimental/cif/cif.py +++ b/bayesflow/experimental/cif/cif.py @@ -99,7 +99,7 @@ def _inverse( def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]: base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) - elbo = self.log_prob(x, conditions=conditions) + elbo = self.log_prob(x, conditions=conditions, training=stage == "training") loss = -keras.ops.mean(elbo) diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index 9a73e1062..781e6148d 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -183,7 +183,7 @@ def compute_metrics( ) -> dict[str, Tensor]: base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) - z, log_density = self(x, conditions=conditions, inverse=False, density=True) + z, log_density = self(x, conditions=conditions, inverse=False, density=True, training=stage == "training") loss = weighted_mean(-log_density, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 402632355..2328d992f 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -145,7 +145,7 @@ def call( def compute_metrics( self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training" ) -> dict[str, Tensor]: - output = self(x, conditions) + output = self(x, conditions, training=stage == "training") metrics = {} # calculate negative score as mean over all scores