From 9fa0743f5a900030337bb5efe7ec3d335592d534 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 21 May 2025 15:41:11 -0400 Subject: [PATCH 1/5] aggregate logs in evaluate --- .../backend_approximators/jax_approximator.py | 135 ++++++++++++++++++ .../tensorflow_approximator.py | 97 +++++++++++++ .../torch_approximator.py | 97 +++++++++++++ 3 files changed, 329 insertions(+) diff --git a/bayesflow/approximators/backend_approximators/jax_approximator.py b/bayesflow/approximators/backend_approximators/jax_approximator.py index a30be99c8..9ce7ae072 100644 --- a/bayesflow/approximators/backend_approximators/jax_approximator.py +++ b/bayesflow/approximators/backend_approximators/jax_approximator.py @@ -3,13 +3,148 @@ from bayesflow.utils import filter_kwargs +from keras.src.backend.jax.trainer import JAXEpochIterator +from keras.src import callbacks as callbacks_module + class JAXApproximator(keras.Model): + def _aggregate_logs(self, logs, step_logs): + if not logs: + return step_logs + + return keras.tree.map_structure(keras.ops.add, logs, step_logs) + + def _mean_logs(self, logs, total_steps): + if total_steps == 0: + return logs + + def _div(x): + return x / total_steps + + return keras.tree.map_structure(_div, logs) + # noinspection PyMethodOverriding def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]: # implemented by each respective architecture raise NotImplementedError + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + self._assert_compile_called("evaluate") + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. + epoch_iterator = JAXEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + self._symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + self._record_training_state_sharding_spec() + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + total_steps = 0 + self.reset_metrics() + + self._jax_state_synced = True + with epoch_iterator.catch_stop_iteration(): + for step, iterator in epoch_iterator: + total_steps += 1 + callbacks.on_test_batch_begin(step) + + if self._jax_state_synced: + # The state may have been synced by a callback. 
+ state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + metrics_variables=True, + purge_model_variables=True, + ) + self._jax_state_synced = False + + # BAYESFLOW: save into step_logs instead of overwriting logs + step_logs, state = self.test_function(state, iterator) + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state + + # BAYESFLOW: aggregate the metrics across all iterations + logs = self._aggregate_logs(logs, step_logs) + + # Setting _jax_state enables callbacks to force a state sync + # if they need to. + self._jax_state = { + # I wouldn't recommend modifying non-trainable model state + # during evaluate(), but it's allowed. + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "metrics_variables": metrics_variables, + } + + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_test_batch_end(step, logs) + + if self.stop_evaluating: + break + + # BAYESFLOW: average the metrics across all iterations + logs = self._mean_logs(logs, total_steps) + + # Reattach state back to model (if not already done by a callback). + self.jax_state_sync() + + # The jax spmd_mode is need for multi-process context, since the + # metrics values are replicated, and we don't want to do a all + # gather, and only need the local copy of the value. + with jax.spmd_mode("allow_all"): + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + self._jax_state = None + if not use_cached_eval_dataset: + # Only clear sharding if evaluate is not called from `fit`. + self._clear_jax_state_sharding() + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + def stateless_compute_metrics( self, trainable_variables: any, diff --git a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py index 8e2579325..2685bebba 100644 --- a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py +++ b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py @@ -3,13 +3,110 @@ from bayesflow.utils import filter_kwargs +from keras.src.backend.tensorflow.trainer import TFEpochIterator +from keras.src import callbacks as callbacks_module + class TensorFlowApproximator(keras.Model): + def _aggregate_logs(self, logs, step_logs): + if not logs: + return step_logs + + return keras.tree.map_structure(keras.ops.add, logs, step_logs) + + def _mean_logs(self, logs, total_steps): + if total_steps == 0: + return logs + + def _div(x): + return x / total_steps + + return keras.tree.map_structure(_div, logs) + # noinspection PyMethodOverriding def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]: # implemented by each respective architecture raise NotImplementedError + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + self._assert_compile_called("evaluate") + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. 
+ epoch_iterator = TFEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + distribute_strategy=self.distribute_strategy, + steps_per_execution=self.steps_per_execution, + ) + + self._maybe_symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + total_steps = 0 + self.reset_metrics() + with epoch_iterator.catch_stop_iteration(): + for step, iterator in epoch_iterator: + total_steps += 1 + + callbacks.on_test_batch_begin(step) + + # BAYESFLOW: save into step_logs instead of overwriting logs + step_logs = self.test_function(iterator) + + # BAYESFLOW: aggregate the metrics across all iterations + logs = self._aggregate_logs(logs, step_logs) + + callbacks.on_test_batch_end(step, logs) + if self.stop_evaluating: + break + + # BAYESFLOW: average the metrics across all iterations + logs = self._mean_logs(logs, total_steps) + + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]: kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics) return self.compute_metrics(**kwargs) diff --git a/bayesflow/approximators/backend_approximators/torch_approximator.py b/bayesflow/approximators/backend_approximators/torch_approximator.py index 685bda8ec..d6fcc8b1d 100644 --- a/bayesflow/approximators/backend_approximators/torch_approximator.py +++ b/bayesflow/approximators/backend_approximators/torch_approximator.py @@ -3,13 +3,110 @@ from bayesflow.utils import filter_kwargs +from keras.src.backend.torch.trainer import TorchEpochIterator +from keras.src import callbacks as callbacks_module + class TorchApproximator(keras.Model): + def _aggregate_logs(self, logs, step_logs): + if not logs: + return step_logs + + return keras.tree.map_structure(keras.ops.add, logs, step_logs) + + def _mean_logs(self, logs, total_steps): + if total_steps == 0: + return logs + + def _div(x): + return x / total_steps + + return keras.tree.map_structure(_div, logs) + # noinspection PyMethodOverriding def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]: # implemented by each respective architecture raise NotImplementedError + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. + epoch_iterator = TorchEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + self._symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. 
+ if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + # Switch the torch Module back to testing mode. + self.eval() + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + total_steps = 0 + self.reset_metrics() + for step, data in epoch_iterator: + total_steps += 1 + + callbacks.on_test_batch_begin(step) + + # BAYESFLOW: save into step_logs instead of overwriting logs + step_logs = self.test_function(data) + + # BAYESFLOW: aggregate the metrics across all iterations + logs = self._aggregate_logs(logs, step_logs) + + callbacks.on_test_batch_end(step, logs) + if self.stop_evaluating: + break + + # BAYESFLOW: average the metrics across all iterations + logs = self._mean_logs(logs, total_steps) + + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]: kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics) return self.compute_metrics(**kwargs) From cf3f397854d21af8b1eddc4c4a85476c198fd861 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 21 May 2025 16:10:46 -0400 Subject: [PATCH 2/5] use up-to-date keras code --- .../backend_approximators/jax_approximator.py | 53 +++++++++---------- .../tensorflow_approximator.py | 45 ++++++++-------- .../torch_approximator.py | 46 ++++++++-------- 3 files changed, 69 insertions(+), 75 deletions(-) diff --git a/bayesflow/approximators/backend_approximators/jax_approximator.py b/bayesflow/approximators/backend_approximators/jax_approximator.py index 9ce7ae072..756c984e6 100644 --- a/bayesflow/approximators/backend_approximators/jax_approximator.py +++ b/bayesflow/approximators/backend_approximators/jax_approximator.py @@ -8,21 +8,6 @@ class JAXApproximator(keras.Model): - def _aggregate_logs(self, logs, step_logs): - if not logs: - return step_logs - - return keras.tree.map_structure(keras.ops.add, logs, step_logs) - - def _mean_logs(self, logs, total_steps): - if total_steps == 0: - return logs - - def _div(x): - return x / total_steps - - return keras.tree.map_structure(_div, logs) - # noinspection PyMethodOverriding def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]: # implemented by each respective architecture @@ -38,6 +23,7 @@ def evaluate( steps=None, callbacks=None, return_dict=False, + aggregate=True, **kwargs, ): self._assert_compile_called("evaluate") @@ -49,7 +35,8 @@ def evaluate( if use_cached_eval_dataset: epoch_iterator = self._eval_epoch_iterator else: - # Create an iterator that yields batches of input/target data. + # Create an iterator that yields batches of + # input/target data. 
epoch_iterator = JAXEpochIterator( x=x, y=y, @@ -82,12 +69,25 @@ def evaluate( total_steps = 0 self.reset_metrics() + def _aggregate_fn(_logs, _step_logs): + if not _logs: + return _step_logs + + return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) + + def _reduce_fn(_logs, _total_steps): + def _div(val): + return val / _total_steps + + return keras.tree.map_structure(_div, _logs) + self._jax_state_synced = True with epoch_iterator.catch_stop_iteration(): for step, iterator in epoch_iterator: - total_steps += 1 callbacks.on_test_batch_begin(step) + total_steps += 1 + if self._jax_state_synced: # The state may have been synced by a callback. state = self._get_jax_state( @@ -98,7 +98,6 @@ def evaluate( ) self._jax_state_synced = False - # BAYESFLOW: save into step_logs instead of overwriting logs step_logs, state = self.test_function(state, iterator) ( trainable_variables, @@ -106,8 +105,10 @@ def evaluate( metrics_variables, ) = state - # BAYESFLOW: aggregate the metrics across all iterations - logs = self._aggregate_logs(logs, step_logs) + if aggregate: + logs = _aggregate_fn(logs, step_logs) + else: + logs = step_logs # Setting _jax_state enables callbacks to force a state sync # if they need to. @@ -120,22 +121,18 @@ def evaluate( } # Dispatch callbacks. This takes care of async dispatch. - callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(step, step_logs) if self.stop_evaluating: break - # BAYESFLOW: average the metrics across all iterations - logs = self._mean_logs(logs, total_steps) + if aggregate: + logs = _reduce_fn(logs, total_steps) # Reattach state back to model (if not already done by a callback). self.jax_state_sync() - # The jax spmd_mode is need for multi-process context, since the - # metrics values are replicated, and we don't want to do a all - # gather, and only need the local copy of the value. 
- with jax.spmd_mode("allow_all"): - logs = self._get_metrics_result_or_logs(logs) + logs = self._get_metrics_result_or_logs(logs) callbacks.on_test_end(logs) self._jax_state = None if not use_cached_eval_dataset: diff --git a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py index 2685bebba..ddbb96909 100644 --- a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py +++ b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py @@ -8,21 +8,6 @@ class TensorFlowApproximator(keras.Model): - def _aggregate_logs(self, logs, step_logs): - if not logs: - return step_logs - - return keras.tree.map_structure(keras.ops.add, logs, step_logs) - - def _mean_logs(self, logs, total_steps): - if total_steps == 0: - return logs - - def _div(x): - return x / total_steps - - return keras.tree.map_structure(_div, logs) - # noinspection PyMethodOverriding def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]: # implemented by each respective architecture @@ -38,6 +23,7 @@ def evaluate( steps=None, callbacks=None, return_dict=False, + aggregate=False, **kwargs, ): self._assert_compile_called("evaluate") @@ -81,24 +67,37 @@ def evaluate( logs = {} total_steps = 0 self.reset_metrics() + + def _aggregate_fn(_logs, _step_logs): + if not _logs: + return _step_logs + + return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) + + def _reduce_fn(_logs, _total_steps): + def _div(val): + return val / _total_steps + + return keras.tree.map_structure(_div, _logs) + with epoch_iterator.catch_stop_iteration(): for step, iterator in epoch_iterator: - total_steps += 1 - callbacks.on_test_batch_begin(step) + total_steps += 1 - # BAYESFLOW: save into step_logs instead of overwriting logs step_logs = self.test_function(iterator) - # BAYESFLOW: aggregate the metrics across all iterations - logs = self._aggregate_logs(logs, step_logs) + if aggregate: + logs = _aggregate_fn(logs, step_logs) + else: + logs = step_logs - callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(step, step_logs) if self.stop_evaluating: break - # BAYESFLOW: average the metrics across all iterations - logs = self._mean_logs(logs, total_steps) + if aggregate: + logs = _reduce_fn(logs, total_steps) logs = self._get_metrics_result_or_logs(logs) callbacks.on_test_end(logs) diff --git a/bayesflow/approximators/backend_approximators/torch_approximator.py b/bayesflow/approximators/backend_approximators/torch_approximator.py index d6fcc8b1d..c153f4805 100644 --- a/bayesflow/approximators/backend_approximators/torch_approximator.py +++ b/bayesflow/approximators/backend_approximators/torch_approximator.py @@ -8,21 +8,6 @@ class TorchApproximator(keras.Model): - def _aggregate_logs(self, logs, step_logs): - if not logs: - return step_logs - - return keras.tree.map_structure(keras.ops.add, logs, step_logs) - - def _mean_logs(self, logs, total_steps): - if total_steps == 0: - return logs - - def _div(x): - return x / total_steps - - return keras.tree.map_structure(_div, logs) - # noinspection PyMethodOverriding def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]: # implemented by each respective architecture @@ -38,6 +23,7 @@ def evaluate( steps=None, callbacks=None, return_dict=False, + aggregate=False, **kwargs, ): # TODO: respect compiled trainable state @@ -82,23 +68,35 @@ def evaluate( logs = {} total_steps = 0 self.reset_metrics() - for step, data in epoch_iterator: - 
total_steps += 1 - callbacks.on_test_batch_begin(step) + def _aggregate_fn(_logs, _step_logs): + if not _logs: + return _step_logs + + return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) + + def _reduce_fn(_logs, _total_steps): + def _div(val): + return val / _total_steps + + return keras.tree.map_structure(_div, _logs) - # BAYESFLOW: save into step_logs instead of overwriting logs + for step, data in epoch_iterator: + callbacks.on_test_batch_begin(step) + total_steps += 1 step_logs = self.test_function(data) - # BAYESFLOW: aggregate the metrics across all iterations - logs = self._aggregate_logs(logs, step_logs) + if aggregate: + logs = _aggregate_fn(logs, step_logs) + else: + logs = step_logs - callbacks.on_test_batch_end(step, logs) + callbacks.on_test_batch_end(step, step_logs) if self.stop_evaluating: break - # BAYESFLOW: average the metrics across all iterations - logs = self._mean_logs(logs, total_steps) + if aggregate: + logs = _reduce_fn(logs, total_steps) logs = self._get_metrics_result_or_logs(logs) callbacks.on_test_end(logs) From c59d8ffeb442c2165740f52d0564e49f1e721a1d Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 21 May 2025 16:14:32 -0400 Subject: [PATCH 3/5] guard for total_steps = 0 --- .../approximators/backend_approximators/jax_approximator.py | 3 +++ .../backend_approximators/tensorflow_approximator.py | 3 +++ .../approximators/backend_approximators/torch_approximator.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/bayesflow/approximators/backend_approximators/jax_approximator.py b/bayesflow/approximators/backend_approximators/jax_approximator.py index 756c984e6..2d1dcba9b 100644 --- a/bayesflow/approximators/backend_approximators/jax_approximator.py +++ b/bayesflow/approximators/backend_approximators/jax_approximator.py @@ -76,6 +76,9 @@ def _aggregate_fn(_logs, _step_logs): return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) def _reduce_fn(_logs, _total_steps): + if total_steps == 0: + return _logs + def _div(val): return val / _total_steps diff --git a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py index ddbb96909..ef2b679cd 100644 --- a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py +++ b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py @@ -75,6 +75,9 @@ def _aggregate_fn(_logs, _step_logs): return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) def _reduce_fn(_logs, _total_steps): + if total_steps == 0: + return _logs + def _div(val): return val / _total_steps diff --git a/bayesflow/approximators/backend_approximators/torch_approximator.py b/bayesflow/approximators/backend_approximators/torch_approximator.py index c153f4805..7a4c44c47 100644 --- a/bayesflow/approximators/backend_approximators/torch_approximator.py +++ b/bayesflow/approximators/backend_approximators/torch_approximator.py @@ -76,6 +76,9 @@ def _aggregate_fn(_logs, _step_logs): return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) def _reduce_fn(_logs, _total_steps): + if total_steps == 0: + return _logs + def _div(val): return val / _total_steps From a3b37d492b5bb27a8d50adfeed5a89eb671e15fe Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 21 May 2025 16:14:32 -0400 Subject: [PATCH 4/5] guard for total_steps = 0 --- .../approximators/backend_approximators/jax_approximator.py | 3 +++ .../backend_approximators/tensorflow_approximator.py | 3 +++ 
.../approximators/backend_approximators/torch_approximator.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/bayesflow/approximators/backend_approximators/jax_approximator.py b/bayesflow/approximators/backend_approximators/jax_approximator.py index 756c984e6..4607fb694 100644 --- a/bayesflow/approximators/backend_approximators/jax_approximator.py +++ b/bayesflow/approximators/backend_approximators/jax_approximator.py @@ -76,6 +76,9 @@ def _aggregate_fn(_logs, _step_logs): return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) def _reduce_fn(_logs, _total_steps): + if _total_steps == 0: + return _logs + def _div(val): return val / _total_steps diff --git a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py index ddbb96909..7326c289b 100644 --- a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py +++ b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py @@ -75,6 +75,9 @@ def _aggregate_fn(_logs, _step_logs): return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) def _reduce_fn(_logs, _total_steps): + if _total_steps == 0: + return _logs + def _div(val): return val / _total_steps diff --git a/bayesflow/approximators/backend_approximators/torch_approximator.py b/bayesflow/approximators/backend_approximators/torch_approximator.py index c153f4805..6970979ab 100644 --- a/bayesflow/approximators/backend_approximators/torch_approximator.py +++ b/bayesflow/approximators/backend_approximators/torch_approximator.py @@ -76,6 +76,9 @@ def _aggregate_fn(_logs, _step_logs): return keras.tree.map_structure(keras.ops.add, _logs, _step_logs) def _reduce_fn(_logs, _total_steps): + if _total_steps == 0: + return _logs + def _div(val): return val / _total_steps From 677c3631f6093440ba91837fcfe10aa2ed871798 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Wed, 21 May 2025 16:18:22 -0400 Subject: [PATCH 5/5] fix default aggregate=False --- .../backend_approximators/tensorflow_approximator.py | 2 +- .../approximators/backend_approximators/torch_approximator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py index 7326c289b..e97d04063 100644 --- a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py +++ b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py @@ -23,7 +23,7 @@ def evaluate( steps=None, callbacks=None, return_dict=False, - aggregate=False, + aggregate=True, **kwargs, ): self._assert_compile_called("evaluate") diff --git a/bayesflow/approximators/backend_approximators/torch_approximator.py b/bayesflow/approximators/backend_approximators/torch_approximator.py index 6970979ab..609ddfeeb 100644 --- a/bayesflow/approximators/backend_approximators/torch_approximator.py +++ b/bayesflow/approximators/backend_approximators/torch_approximator.py @@ -23,7 +23,7 @@ def evaluate( steps=None, callbacks=None, return_dict=False, - aggregate=False, + aggregate=True, **kwargs, ): # TODO: respect compiled trainable state
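
For readers outside the patch context: the series changes `evaluate()` so that, when the new `aggregate` flag is set, metric logs are summed across every evaluation batch and divided by the number of steps, instead of reporting only the logs of the final batch. The backend-agnostic sketch below mirrors the nested `_aggregate_fn`/`_reduce_fn` helpers added in the later patches; the module-level helper names and the example metric values are illustrative assumptions, not part of the patch.

    import keras

    def aggregate_logs(logs, step_logs):
        # First batch: adopt the step logs as-is.
        if not logs:
            return step_logs
        # Element-wise running sum over the (possibly nested) dict of metrics.
        return keras.tree.map_structure(keras.ops.add, logs, step_logs)

    def reduce_logs(logs, total_steps):
        # Guard against an empty evaluation, mirroring the total_steps == 0 check.
        if total_steps == 0:
            return logs
        return keras.tree.map_structure(lambda v: v / total_steps, logs)

    # Illustrative per-batch logs (assumed values, not taken from the patch).
    per_batch_logs = [
        {"loss": 1.0, "nll": 0.8},
        {"loss": 0.5, "nll": 0.4},
    ]

    logs, total_steps = {}, 0
    for step_logs in per_batch_logs:
        total_steps += 1
        logs = aggregate_logs(logs, step_logs)

    print(reduce_logs(logs, total_steps))  # averaged over both batches: loss 0.75, nll 0.6

With the patches applied, the same averaging happens inside `evaluate()`, so a call along the lines of `approximator.evaluate(validation_data, aggregate=True, return_dict=True)` (with `validation_data` standing in for whatever dataset the approximator accepts) returns batch-averaged metrics on every backend, while `aggregate=False` keeps the stock behaviour of overwriting the logs on each step.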