diff --git a/bayesflow/approximators/backend_approximators/jax_approximator.py b/bayesflow/approximators/backend_approximators/jax_approximator.py
index a30be99c8..4607fb694 100644
--- a/bayesflow/approximators/backend_approximators/jax_approximator.py
+++ b/bayesflow/approximators/backend_approximators/jax_approximator.py
@@ -3,6 +3,9 @@
 
 from bayesflow.utils import filter_kwargs
 
+from keras.src.backend.jax.trainer import JAXEpochIterator
+from keras.src import callbacks as callbacks_module
+
 
 class JAXApproximator(keras.Model):
     # noinspection PyMethodOverriding
@@ -10,6 +13,138 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
         # implemented by each respective architecture
         raise NotImplementedError
 
+    def evaluate(
+        self,
+        x=None,
+        y=None,
+        batch_size=None,
+        verbose="auto",
+        sample_weight=None,
+        steps=None,
+        callbacks=None,
+        return_dict=False,
+        aggregate=True,
+        **kwargs,
+    ):
+        self._assert_compile_called("evaluate")
+        # TODO: respect compiled trainable state
+        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
+        if kwargs:
+            raise ValueError(f"Arguments not recognized: {kwargs}")
+
+        if use_cached_eval_dataset:
+            epoch_iterator = self._eval_epoch_iterator
+        else:
+            # Create an iterator that yields batches of
+            # input/target data.
+            epoch_iterator = JAXEpochIterator(
+                x=x,
+                y=y,
+                sample_weight=sample_weight,
+                batch_size=batch_size,
+                steps_per_epoch=steps,
+                shuffle=False,
+                steps_per_execution=self.steps_per_execution,
+            )
+
+        self._symbolic_build(iterator=epoch_iterator)
+        epoch_iterator.reset()
+
+        # Container that configures and calls callbacks.
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+        self._record_training_state_sharding_spec()
+
+        self.make_test_function()
+        self.stop_evaluating = False
+        callbacks.on_test_begin()
+        logs = {}
+        total_steps = 0
+        self.reset_metrics()
+
+        def _aggregate_fn(_logs, _step_logs):
+            if not _logs:
+                return _step_logs
+
+            return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)
+
+        def _reduce_fn(_logs, _total_steps):
+            if _total_steps == 0:
+                return _logs
+
+            def _div(val):
+                return val / _total_steps
+
+            return keras.tree.map_structure(_div, _logs)
+
+        self._jax_state_synced = True
+        with epoch_iterator.catch_stop_iteration():
+            for step, iterator in epoch_iterator:
+                callbacks.on_test_batch_begin(step)
+
+                total_steps += 1
+
+                if self._jax_state_synced:
+                    # The state may have been synced by a callback.
+                    state = self._get_jax_state(
+                        trainable_variables=True,
+                        non_trainable_variables=True,
+                        metrics_variables=True,
+                        purge_model_variables=True,
+                    )
+                    self._jax_state_synced = False
+
+                step_logs, state = self.test_function(state, iterator)
+                (
+                    trainable_variables,
+                    non_trainable_variables,
+                    metrics_variables,
+                ) = state
+
+                if aggregate:
+                    logs = _aggregate_fn(logs, step_logs)
+                else:
+                    logs = step_logs
+
+                # Setting _jax_state enables callbacks to force a state sync
+                # if they need to.
+                self._jax_state = {
+                    # I wouldn't recommend modifying non-trainable model state
+                    # during evaluate(), but it's allowed.
+                    "trainable_variables": trainable_variables,
+                    "non_trainable_variables": non_trainable_variables,
+                    "metrics_variables": metrics_variables,
+                }
+
+                # Dispatch callbacks. This takes care of async dispatch.
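+                # (Due to JAX's asynchronous dispatch, step_logs may still be
+                # unmaterialized device arrays at this point; callbacks only
+                # pay the synchronization cost when they read concrete values.)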
+                callbacks.on_test_batch_end(step, step_logs)
+
+                if self.stop_evaluating:
+                    break
+
+        if aggregate:
+            logs = _reduce_fn(logs, total_steps)
+
+        # Reattach state back to model (if not already done by a callback).
+        self.jax_state_sync()
+
+        logs = self._get_metrics_result_or_logs(logs)
+        callbacks.on_test_end(logs)
+        self._jax_state = None
+        if not use_cached_eval_dataset:
+            # Only clear sharding if evaluate is not called from `fit`.
+            self._clear_jax_state_sharding()
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
+
     def stateless_compute_metrics(
         self,
         trainable_variables: any,
diff --git a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py
index 8e2579325..e97d04063 100644
--- a/bayesflow/approximators/backend_approximators/tensorflow_approximator.py
+++ b/bayesflow/approximators/backend_approximators/tensorflow_approximator.py
@@ -3,6 +3,9 @@
 
 from bayesflow.utils import filter_kwargs
 
+from keras.src.backend.tensorflow.trainer import TFEpochIterator
+from keras.src import callbacks as callbacks_module
+
 
 class TensorFlowApproximator(keras.Model):
     # noinspection PyMethodOverriding
@@ -10,6 +13,102 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
         # implemented by each respective architecture
         raise NotImplementedError
 
+    def evaluate(
+        self,
+        x=None,
+        y=None,
+        batch_size=None,
+        verbose="auto",
+        sample_weight=None,
+        steps=None,
+        callbacks=None,
+        return_dict=False,
+        aggregate=True,
+        **kwargs,
+    ):
+        self._assert_compile_called("evaluate")
+        # TODO: respect compiled trainable state
+        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
+        if kwargs:
+            raise ValueError(f"Arguments not recognized: {kwargs}")
+
+        if use_cached_eval_dataset:
+            epoch_iterator = self._eval_epoch_iterator
+        else:
+            # Create an iterator that yields batches of input/target data.
+            epoch_iterator = TFEpochIterator(
+                x=x,
+                y=y,
+                sample_weight=sample_weight,
+                batch_size=batch_size,
+                steps_per_epoch=steps,
+                shuffle=False,
+                distribute_strategy=self.distribute_strategy,
+                steps_per_execution=self.steps_per_execution,
+            )
+
+        self._maybe_symbolic_build(iterator=epoch_iterator)
+        epoch_iterator.reset()
+
+        # Container that configures and calls callbacks.
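+        # (A plain list of callbacks is wrapped in a CallbackList, which
+        # also owns the progress bar displayed when verbose is non-zero.)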
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+
+        self.make_test_function()
+        self.stop_evaluating = False
+        callbacks.on_test_begin()
+        logs = {}
+        total_steps = 0
+        self.reset_metrics()
+
+        def _aggregate_fn(_logs, _step_logs):
+            if not _logs:
+                return _step_logs
+
+            return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)
+
+        def _reduce_fn(_logs, _total_steps):
+            if _total_steps == 0:
+                return _logs
+
+            def _div(val):
+                return val / _total_steps
+
+            return keras.tree.map_structure(_div, _logs)
+
+        with epoch_iterator.catch_stop_iteration():
+            for step, iterator in epoch_iterator:
+                callbacks.on_test_batch_begin(step)
+                total_steps += 1
+
+                step_logs = self.test_function(iterator)
+
+                if aggregate:
+                    logs = _aggregate_fn(logs, step_logs)
+                else:
+                    logs = step_logs
+
+                callbacks.on_test_batch_end(step, step_logs)
+                if self.stop_evaluating:
+                    break
+
+        if aggregate:
+            logs = _reduce_fn(logs, total_steps)
+
+        logs = self._get_metrics_result_or_logs(logs)
+        callbacks.on_test_end(logs)
+
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
+
     def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
         kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
         return self.compute_metrics(**kwargs)
diff --git a/bayesflow/approximators/backend_approximators/torch_approximator.py b/bayesflow/approximators/backend_approximators/torch_approximator.py
index 685bda8ec..609ddfeeb 100644
--- a/bayesflow/approximators/backend_approximators/torch_approximator.py
+++ b/bayesflow/approximators/backend_approximators/torch_approximator.py
@@ -3,6 +3,9 @@
 
 from bayesflow.utils import filter_kwargs
 
+from keras.src.backend.torch.trainer import TorchEpochIterator
+from keras.src import callbacks as callbacks_module
+
 
 class TorchApproximator(keras.Model):
     # noinspection PyMethodOverriding
@@ -10,6 +13,101 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
         # implemented by each respective architecture
         raise NotImplementedError
 
+    def evaluate(
+        self,
+        x=None,
+        y=None,
+        batch_size=None,
+        verbose="auto",
+        sample_weight=None,
+        steps=None,
+        callbacks=None,
+        return_dict=False,
+        aggregate=True,
+        **kwargs,
+    ):
+        # TODO: respect compiled trainable state
+        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
+        if kwargs:
+            raise ValueError(f"Arguments not recognized: {kwargs}")
+
+        if use_cached_eval_dataset:
+            epoch_iterator = self._eval_epoch_iterator
+        else:
+            # Create an iterator that yields batches of input/target data.
+            epoch_iterator = TorchEpochIterator(
+                x=x,
+                y=y,
+                sample_weight=sample_weight,
+                batch_size=batch_size,
+                steps_per_epoch=steps,
+                shuffle=False,
+                steps_per_execution=self.steps_per_execution,
+            )
+
+        self._symbolic_build(iterator=epoch_iterator)
+        epoch_iterator.reset()
+
+        # Container that configures and calls callbacks.
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+
+        # Switch the torch Module back to testing mode.
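+        # (eval() disables training-only behavior such as dropout and
+        # batch-norm running-statistics updates for the evaluation pass.)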
+        self.eval()
+
+        self.make_test_function()
+        self.stop_evaluating = False
+        callbacks.on_test_begin()
+        logs = {}
+        total_steps = 0
+        self.reset_metrics()
+
+        def _aggregate_fn(_logs, _step_logs):
+            if not _logs:
+                return _step_logs
+
+            return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)
+
+        def _reduce_fn(_logs, _total_steps):
+            if _total_steps == 0:
+                return _logs
+
+            def _div(val):
+                return val / _total_steps
+
+            return keras.tree.map_structure(_div, _logs)
+
+        for step, data in epoch_iterator:
+            callbacks.on_test_batch_begin(step)
+            total_steps += 1
+            step_logs = self.test_function(data)
+
+            if aggregate:
+                logs = _aggregate_fn(logs, step_logs)
+            else:
+                logs = step_logs
+
+            callbacks.on_test_batch_end(step, step_logs)
+            if self.stop_evaluating:
+                break
+
+        if aggregate:
+            logs = _reduce_fn(logs, total_steps)
+
+        logs = self._get_metrics_result_or_logs(logs)
+        callbacks.on_test_end(logs)
+
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
+
     def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
         kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
         return self.compute_metrics(**kwargs)
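
All three backends gain the same `aggregate` keyword on `evaluate`. A minimal usage
sketch (illustrative only; `approximator` and `validation_data` are placeholder
names, not part of the diff):

    # Default aggregate=True: per-batch logs are summed via
    # keras.tree.map_structure(keras.ops.add, ...) and divided by the
    # number of batches, so each metric is a mean over evaluation batches.
    metrics = approximator.evaluate(validation_data, return_dict=True)

    # aggregate=False skips the averaging and returns the logs of the
    # last evaluated batch only.
    last_batch_logs = approximator.evaluate(
        validation_data, return_dict=True, aggregate=False
    )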