LarsKue commented May 21, 2025

Fixes #481

LarsKue requested a review from Copilot May 21, 2025 19:56
LarsKue self-assigned this May 21, 2025
LarsKue added the fix label (Pull request that fixes a bug) May 21, 2025
LarsKue requested a review from stefanradev93 May 21, 2025 19:57
codecov bot commented May 21, 2025

Copilot AI left a comment

Pull Request Overview

Implements cumulative logging and averaging in the evaluate method for all backend approximators to fix metrics aggregation issues.

  • Introduce _aggregate_logs and _mean_logs helpers to accumulate and normalize batch metrics.
  • Override evaluate in Torch, TensorFlow, and JAX approximators using their respective *EpochIterator and callback flows.
  • Wire up Keras CallbackList and ensure per-batch callbacks with aggregated logs.
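
The accumulate-then-normalize idea behind the two helpers can be sketched in plain Python. This is a minimal illustration, not the PR's code: the function names `aggregate_logs` and `mean_logs`, their signatures, and the equal weighting of every batch are assumptions, since the actual helper bodies are not shown in this conversation.

```python
from collections.abc import Mapping


# Hypothetical stand-ins for the _aggregate_logs / _mean_logs helpers named above.
def aggregate_logs(total: Mapping[str, float], batch_logs: Mapping[str, float]) -> dict[str, float]:
    """Add one batch's metrics to the running totals and return a new dict."""
    result = dict(total)
    for name, value in batch_logs.items():
        result[name] = result.get(name, 0.0) + float(value)
    return result


def mean_logs(total: Mapping[str, float], num_batches: int) -> dict[str, float]:
    """Divide the running totals by the number of batches seen."""
    if num_batches <= 0:
        raise ValueError("num_batches must be positive")
    return {name: value / num_batches for name, value in total.items()}


# Three batches with different losses; without aggregation only the last
# value (6.0) would survive, with aggregation the result is the mean.
batches = [{"loss": 1.0}, {"loss": 2.0}, {"loss": 6.0}]
total: dict[str, float] = {}
for logs in batches:
    total = aggregate_logs(total, logs)
averaged = mean_logs(total, len(batches))
print(averaged)
```

Note that this sketch weights every batch equally; if the final batch is smaller than the others, a sample-weighted mean would give a slightly different result.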

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File: Description

  • bayesflow/approximators/backend_approximators/torch_approximator.py: Added _aggregate_logs, _mean_logs, and updated evaluate with TorchEpochIterator and callbacks.
  • bayesflow/approximators/backend_approximators/tensorflow_approximator.py: Added _aggregate_logs, _mean_logs, and updated evaluate with TFEpochIterator and callbacks.
  • bayesflow/approximators/backend_approximators/jax_approximator.py: Added _aggregate_logs, _mean_logs, and updated evaluate with JAXEpochIterator, state sync, and callbacks.

Comments suppressed due to low confidence (2)

bayesflow/approximators/backend_approximators/jax_approximator.py:87

  • [nitpick] The loop variable 'iterator' shadows the epoch_iterator and may be confusing; consider renaming it to 'batch_data' or similar to clarify its purpose.
for step, iterator in epoch_iterator:

bayesflow/approximators/backend_approximators/tensorflow_approximator.py:31

  • Add unit tests for this new evaluate implementation to verify that log aggregation and averaging behave as expected across multiple batches.
def evaluate(

LarsKue requested a review from Copilot May 21, 2025 20:11
Copilot AI left a comment

Pull Request Overview

Adds an aggregate option to the evaluate method in all backend approximators, allowing batch-wise metrics to be summed and averaged rather than overwritten each step.

  • Introduce _aggregate_fn and _reduce_fn in evaluate to accumulate and average metrics.
  • Add aggregate and return_dict parameters to control output format.
  • Ensure consistency in callback invocation and test function setup across Torch, TensorFlow, and JAX backends.
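
How the aggregate and return_dict switches might interact can be sketched without any backend. This is a simplified illustration under stated assumptions: `evaluate_sketch`, its `test_step` argument, and the list return value for return_dict=False are hypothetical, and callbacks, epoch iterators, and JAX state handling are left out entirely.

```python
from collections.abc import Callable, Iterable, Mapping


def evaluate_sketch(
    batches: Iterable,
    test_step: Callable[[object], Mapping[str, float]],
    aggregate: bool = False,
    return_dict: bool = False,
):
    """Hypothetical evaluation loop: run test_step per batch and reduce the logs."""
    logs: dict[str, float] = {}
    total: dict[str, float] = {}
    num_steps = 0

    for batch in batches:
        # Without aggregation, each step simply overwrites the previous logs.
        logs = dict(test_step(batch))
        num_steps += 1
        if aggregate:
            # With aggregation, keep a running sum per metric name.
            for name, value in logs.items():
                total[name] = total.get(name, 0.0) + float(value)

    if aggregate and num_steps > 0:
        # Reduce the running sums to per-batch means.
        logs = {name: value / num_steps for name, value in total.items()}

    return logs if return_dict else list(logs.values())


def fake_test_step(batch) -> dict[str, float]:
    """Stand-in for a backend test step: the 'loss' is just the batch value."""
    return {"loss": float(batch)}


last_only = evaluate_sketch([1, 2, 3], fake_test_step, aggregate=False, return_dict=True)
averaged = evaluate_sketch([1, 2, 3], fake_test_step, aggregate=True, return_dict=True)
print(last_only, averaged)
```

With aggregate=False the result reflects only the final batch, which is the overwriting behavior the overview describes; with aggregate=True it is the mean over all batches.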

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File: Description

  • bayesflow/approximators/backend_approximators/torch_approximator.py: Added evaluate override with optional aggregation logic
  • bayesflow/approximators/backend_approximators/tensorflow_approximator.py: Added evaluate override with optional aggregation logic
  • bayesflow/approximators/backend_approximators/jax_approximator.py: Added evaluate override with optional aggregation logic and JAX state handling

Comments suppressed due to low confidence (4)

bayesflow/approximators/backend_approximators/torch_approximator.py:26

  • Add unit tests covering both aggregate=True and aggregate=False paths to verify that metrics are correctly summed and averaged.
        aggregate=False,
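
A test along the lines suggested here could look roughly like the following. It is only a sketch: `run_evaluate` is a hypothetical, backend-free stand-in for the reduction inside evaluate, not the approximator method itself, and a real test would call evaluate on a compiled approximator with a small dataset.

```python
import unittest


def run_evaluate(step_logs: list[dict[str, float]], aggregate: bool) -> dict[str, float]:
    """Hypothetical function under test: reduce a list of per-batch log dicts."""
    if not step_logs:
        return {}
    if not aggregate:
        # Overwriting behavior: only the last batch's logs remain.
        return dict(step_logs[-1])
    names = step_logs[0].keys()
    return {name: sum(logs[name] for logs in step_logs) / len(step_logs) for name in names}


class EvaluateAggregationTest(unittest.TestCase):
    # Batches chosen so that the mean (2.0) differs from the last value (0.0).
    STEP_LOGS = [{"loss": 4.0}, {"loss": 2.0}, {"loss": 0.0}]

    def test_aggregate_true_averages_over_batches(self):
        self.assertEqual(run_evaluate(self.STEP_LOGS, aggregate=True), {"loss": 2.0})

    def test_aggregate_false_keeps_last_batch(self):
        self.assertEqual(run_evaluate(self.STEP_LOGS, aggregate=False), {"loss": 0.0})

    def test_empty_input_returns_empty_logs(self):
        self.assertEqual(run_evaluate([], aggregate=True), {})
```

Choosing batch values whose mean differs from the last value makes the two paths distinguishable; identical batch losses would let a broken implementation pass.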

bayesflow/approximators/backend_approximators/torch_approximator.py:29

  • Add a call to self._assert_compile_called("evaluate") at the start of evaluate to ensure the model has been compiled before evaluation.
# TODO: respect compiled trainable state

bayesflow/approximators/backend_approximators/jax_approximator.py:26

  • The default aggregate=True in JAX differs from aggregate=False in the Torch and TensorFlow backends. Align the default value for consistency across backends.
        aggregate=True,

bayesflow/approximators/backend_approximators/torch_approximator.py:16

  • This new evaluate method lacks docstrings for the aggregate and return_dict parameters; please add descriptions and expected behavior.
    def evaluate(

LarsKue added 4 commits May 21, 2025 16:14
… aggregate-logs-in-evaluate

# Conflicts:
#	bayesflow/approximators/backend_approximators/jax_approximator.py
#	bayesflow/approximators/backend_approximators/tensorflow_approximator.py
#	bayesflow/approximators/backend_approximators/torch_approximator.py

LarsKue commented May 22, 2025

Superseded by #485

LarsKue closed this May 22, 2025