Aggregate logs in evaluate
#483
Conversation
Pull Request Overview
Implements cumulative logging and averaging in the evaluate method for all backend approximators to fix metrics aggregation issues.
- Introduce `_aggregate_logs` and `_mean_logs` helpers to accumulate and normalize batch metrics (see the sketch after this list).
- Override `evaluate` in the Torch, TensorFlow, and JAX approximators using their respective `*EpochIterator` and callback flows.
- Wire up the Keras `CallbackList` and ensure per-batch callbacks receive the aggregated logs.
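For orientation, here is a minimal, hypothetical sketch of what helpers like these could look like, assuming they operate on plain dicts of scalar metrics. The names follow the overview above, but the PR's actual implementation may differ in signature and detail.

```python
# Hypothetical sketch of the two helpers named in the overview; not the PR's exact code.

def _aggregate_logs(aggregated, batch_logs):
    """Add the current batch's metric values onto the running totals."""
    if aggregated is None:
        # First batch: start the running totals from a copy of the batch logs.
        return dict(batch_logs)
    return {key: aggregated[key] + value for key, value in batch_logs.items()}


def _mean_logs(aggregated, num_batches):
    """Normalize the accumulated totals into per-batch averages."""
    return {key: value / num_batches for key, value in aggregated.items()}


# Usage: accumulate over batches, then normalize once at the end.
per_batch_logs = [{"loss": 1.0}, {"loss": 2.0}, {"loss": 3.0}]
logs = None
for batch_logs in per_batch_logs:
    logs = _aggregate_logs(logs, batch_logs)
logs = _mean_logs(logs, len(per_batch_logs))  # {"loss": 2.0}
```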
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| bayesflow/approximators/backend_approximators/torch_approximator.py | Added _aggregate_logs, _mean_logs, and updated evaluate with TorchEpochIterator and callbacks. |
| bayesflow/approximators/backend_approximators/tensorflow_approximator.py | Added _aggregate_logs, _mean_logs, and updated evaluate with TFEpochIterator and callbacks. |
| bayesflow/approximators/backend_approximators/jax_approximator.py | Added _aggregate_logs, _mean_logs, and updated evaluate with JAXEpochIterator, state sync, and callbacks. |
Comments suppressed due to low confidence (2)
bayesflow/approximators/backend_approximators/jax_approximator.py:87
- [nitpick] The loop variable 'iterator' shadows the epoch_iterator and may be confusing; consider renaming it to 'batch_data' or similar to clarify its purpose.
for step, iterator in epoch_iterator:
bayesflow/approximators/backend_approximators/tensorflow_approximator.py:31
- Add unit tests for this new evaluate implementation to verify that log aggregation and averaging behave as expected across multiple batches (a test sketch follows below).
def evaluate(
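Picking up the unit-test suggestion, one possible pytest sketch is shown below. The import path and helper signatures are assumptions based on the review comments and the helper sketch above, not the PR's actual test suite; adjust the import if the helpers are methods on the approximator rather than module-level functions.

```python
# Hypothetical test sketch; import path and helper signatures are assumptions.
import pytest

from bayesflow.approximators.backend_approximators.torch_approximator import (
    _aggregate_logs,
    _mean_logs,
)


def test_evaluate_logs_are_averaged_across_batches():
    per_batch_logs = [{"loss": 1.0}, {"loss": 2.0}, {"loss": 3.0}]

    aggregated = None
    for batch_logs in per_batch_logs:
        aggregated = _aggregate_logs(aggregated, batch_logs)
    result = _mean_logs(aggregated, len(per_batch_logs))

    # Aggregation should yield the mean over batches, not the last batch's value.
    assert result["loss"] == pytest.approx(2.0)
```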
Pull Request Overview
Adds an aggregate option to the evaluate method in all backend approximators, allowing batch-wise metrics to be summed and averaged rather than overwritten each step.
- Introduce `_aggregate_fn` and `_reduce_fn` in `evaluate` to accumulate and average metrics (see the sketch after this list).
- Add `aggregate` and `return_dict` parameters to control the output format.
- Ensure consistent callback invocation and test-function setup across the Torch, TensorFlow, and JAX backends.
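As a rough, backend-agnostic illustration of the control flow described above (not the PR's exact code): `test_function`, `epoch_iterator`, and `callbacks` stand in for the backend-specific pieces, and the parameter names and defaults follow the overview but are otherwise illustrative.

```python
# Hypothetical, backend-agnostic sketch of evaluate with an aggregate option.
def evaluate_sketch(test_function, epoch_iterator, callbacks, aggregate=True, return_dict=True):
    logs = None
    num_batches = 0

    callbacks.on_test_begin()
    for step, data in epoch_iterator:
        callbacks.on_test_batch_begin(step)
        batch_logs = test_function(data)

        if aggregate:
            # Sum metrics across batches instead of overwriting them each step.
            if logs is None:
                logs = dict(batch_logs)
            else:
                logs = {key: logs[key] + value for key, value in batch_logs.items()}
            num_batches += 1
        else:
            # Previous behavior: keep only the most recent batch's logs.
            logs = batch_logs

        callbacks.on_test_batch_end(step, batch_logs)

    if aggregate and num_batches > 0:
        # Reduce the accumulated sums to per-batch averages.
        logs = {key: value / num_batches for key, value in logs.items()}

    callbacks.on_test_end(logs)
    return logs if return_dict else list(logs.values())
```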
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| bayesflow/approximators/backend_approximators/torch_approximator.py | Added evaluate override with optional aggregation logic |
| bayesflow/approximators/backend_approximators/tensorflow_approximator.py | Added evaluate override with optional aggregation logic |
| bayesflow/approximators/backend_approximators/jax_approximator.py | Added evaluate override with optional aggregation logic and JAX state handling |
Comments suppressed due to low confidence (4)
bayesflow/approximators/backend_approximators/torch_approximator.py:26
- Add unit tests covering both the `aggregate=True` and `aggregate=False` paths to verify that metrics are correctly summed and averaged.
aggregate=False,
bayesflow/approximators/backend_approximators/torch_approximator.py:29
- Add a call to `self._assert_compile_called("evaluate")` at the start of `evaluate` to ensure the model has been compiled before evaluation.
# TODO: respect compiled trainable state
bayesflow/approximators/backend_approximators/jax_approximator.py:26
- The default `aggregate=True` in JAX differs from `aggregate=False` in the Torch and TensorFlow backends. Align the default value for consistency across backends.
aggregate=True,
bayesflow/approximators/backend_approximators/torch_approximator.py:16
- This new `evaluate` method lacks docstrings for the `aggregate` and `return_dict` parameters; please add descriptions and expected behavior (see the sketch after these comments).
def evaluate(
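Several of these points could be addressed together at the method signature. The following is a hypothetical sketch only: the `aggregate` default shown here is illustrative (the review leaves the aligned value open), and `_assert_compile_called` is the Keras trainer helper named in the comment above.

```python
# Hypothetical signature/docstring sketch; defaults and wording are illustrative.
def evaluate(self, *args, aggregate=True, return_dict=False, **kwargs):
    """Evaluate the approximator on the given data.

    Parameters
    ----------
    aggregate : bool
        If True, per-batch metrics are summed over the run and averaged at the
        end; if False, only the logs of the last evaluated batch are kept.
    return_dict : bool
        If True, return a dict mapping metric names to values; otherwise
        return the metric values as a list.
    """
    # Fail early and clearly if compile() has not been called yet.
    self._assert_compile_called("evaluate")
    ...
```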
… aggregate-logs-in-evaluate
# Conflicts:
#   bayesflow/approximators/backend_approximators/jax_approximator.py
#   bayesflow/approximators/backend_approximators/tensorflow_approximator.py
#   bayesflow/approximators/backend_approximators/torch_approximator.py
Superseded by #485
Fixes #481