Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions bayesflow/approximators/backend_approximators/jax_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,148 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.jax.trainer import JAXEpochIterator
from keras.src import callbacks as callbacks_module


class JAXApproximator(keras.Model):
def _aggregate_logs(self, logs, step_logs):
    """Accumulate per-step logs into a running element-wise sum.

    On the first step (``logs`` empty/falsy) the step logs are adopted
    as-is; afterwards matching entries of both structures are added.
    """
    if logs:
        return keras.tree.map_structure(keras.ops.add, logs, step_logs)
    return step_logs

def _mean_logs(self, logs, total_steps):
    """Turn summed per-step logs into per-step averages.

    Returns ``logs`` unchanged when no steps were run, so an empty
    evaluation never divides by zero.
    """
    if total_steps == 0:
        return logs
    return keras.tree.map_structure(lambda value: value / total_steps, logs)

# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
    """Abstract hook: each concrete architecture supplies its own metrics."""
    raise NotImplementedError

def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Evaluate the model, averaging metrics over all evaluation steps.

    Adapted from the Keras JAX trainer's ``evaluate`` (see the
    ``JAXEpochIterator`` import). The BayesFlow-specific change is that
    per-step logs are summed via ``_aggregate_logs`` and divided by the
    step count via ``_mean_logs``, instead of keeping only the logs of
    the last executed step.

    Parameters follow ``keras.Model.evaluate``. Returns a dict of metric
    values when ``return_dict=True``, otherwise the flattened metric
    values in declaration order.
    """
    self._assert_compile_called("evaluate")
    # TODO: respect compiled trainable state
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")

    if use_cached_eval_dataset:
        # Reuse the iterator built by `fit()` for its validation pass.
        epoch_iterator = self._eval_epoch_iterator
    else:
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = JAXEpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

    self._symbolic_build(iterator=epoch_iterator)
    epoch_iterator.reset()

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )
    self._record_training_state_sharding_spec()

    self.make_test_function()
    self.stop_evaluating = False
    callbacks.on_test_begin()
    logs = {}
    total_steps = 0
    self.reset_metrics()

    self._jax_state_synced = True
    with epoch_iterator.catch_stop_iteration():
        for step, iterator in epoch_iterator:
            total_steps += 1
            callbacks.on_test_batch_begin(step)

            if self._jax_state_synced:
                # The state may have been synced by a callback.
                state = self._get_jax_state(
                    trainable_variables=True,
                    non_trainable_variables=True,
                    metrics_variables=True,
                    purge_model_variables=True,
                )
                self._jax_state_synced = False

            # BAYESFLOW: save into step_logs instead of overwriting logs
            step_logs, state = self.test_function(state, iterator)
            (
                trainable_variables,
                non_trainable_variables,
                metrics_variables,
            ) = state

            # BAYESFLOW: aggregate the metrics across all iterations
            logs = self._aggregate_logs(logs, step_logs)

            # Setting _jax_state enables callbacks to force a state sync
            # if they need to.
            self._jax_state = {
                # I wouldn't recommend modifying non-trainable model state
                # during evaluate(), but it's allowed.
                "trainable_variables": trainable_variables,
                "non_trainable_variables": non_trainable_variables,
                "metrics_variables": metrics_variables,
            }

            # Dispatch callbacks. This takes care of async dispatch.
            # NOTE: callbacks receive the running sum, not the per-step logs.
            callbacks.on_test_batch_end(step, logs)

            if self.stop_evaluating:
                break

    # BAYESFLOW: average the metrics across all iterations
    logs = self._mean_logs(logs, total_steps)

    # Reattach state back to model (if not already done by a callback).
    self.jax_state_sync()

    # The jax spmd_mode is needed in a multi-process context, since the
    # metrics values are replicated and we don't want to do an all-gather;
    # we only need the local copy of the value.
    with jax.spmd_mode("allow_all"):
        logs = self._get_metrics_result_or_logs(logs)
    callbacks.on_test_end(logs)
    self._jax_state = None
    if not use_cached_eval_dataset:
        # Only clear sharding if evaluate is not called from `fit`.
        self._clear_jax_state_sharding()
    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)

Check warning on line 146 in bayesflow/approximators/backend_approximators/jax_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/jax_approximator.py#L143-L146

Added lines #L143 - L146 were not covered by tests

def stateless_compute_metrics(
self,
trainable_variables: any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,110 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.tensorflow.trainer import TFEpochIterator
from keras.src import callbacks as callbacks_module

Check warning on line 7 in bayesflow/approximators/backend_approximators/tensorflow_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/tensorflow_approximator.py#L6-L7

Added lines #L6 - L7 were not covered by tests


class TensorFlowApproximator(keras.Model):
def _aggregate_logs(self, logs, step_logs):
    """Accumulate per-step logs into a running element-wise sum.

    On the first step (``logs`` empty/falsy) the step logs are adopted
    as-is; afterwards matching entries of both structures are added.
    """
    if logs:
        return keras.tree.map_structure(keras.ops.add, logs, step_logs)
    return step_logs

Check warning on line 15 in bayesflow/approximators/backend_approximators/tensorflow_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/tensorflow_approximator.py#L15

Added line #L15 was not covered by tests

def _mean_logs(self, logs, total_steps):
    """Turn summed per-step logs into per-step averages.

    Returns ``logs`` unchanged when no steps were run, so an empty
    evaluation never divides by zero.
    """
    if total_steps == 0:
        return logs
    return keras.tree.map_structure(lambda value: value / total_steps, logs)

Check warning on line 24 in bayesflow/approximators/backend_approximators/tensorflow_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/tensorflow_approximator.py#L24

Added line #L24 was not covered by tests

# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
    """Abstract hook: each concrete architecture supplies its own metrics."""
    raise NotImplementedError

def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Evaluate the model, averaging metrics over all evaluation steps.

    Adapted from the Keras TensorFlow trainer's ``evaluate`` (see the
    ``TFEpochIterator`` import). The BayesFlow-specific change is that
    per-step logs are summed via ``_aggregate_logs`` and divided by the
    step count via ``_mean_logs``, instead of keeping only the logs of
    the last executed step.

    Parameters follow ``keras.Model.evaluate``. Returns a dict of metric
    values when ``return_dict=True``, otherwise the flattened metric
    values in declaration order.
    """
    self._assert_compile_called("evaluate")
    # TODO: respect compiled trainable state
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")

    if use_cached_eval_dataset:
        # Reuse the iterator built by `fit()` for its validation pass.
        epoch_iterator = self._eval_epoch_iterator
    else:
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = TFEpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            distribute_strategy=self.distribute_strategy,
            steps_per_execution=self.steps_per_execution,
        )

    self._maybe_symbolic_build(iterator=epoch_iterator)
    epoch_iterator.reset()

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    self.make_test_function()
    self.stop_evaluating = False
    callbacks.on_test_begin()
    logs = {}
    total_steps = 0
    self.reset_metrics()
    with epoch_iterator.catch_stop_iteration():
        for step, iterator in epoch_iterator:
            total_steps += 1

            callbacks.on_test_batch_begin(step)

            # BAYESFLOW: save into step_logs instead of overwriting logs
            step_logs = self.test_function(iterator)

            # BAYESFLOW: aggregate the metrics across all iterations
            logs = self._aggregate_logs(logs, step_logs)

            # NOTE: callbacks receive the running sum, not the per-step logs.
            callbacks.on_test_batch_end(step, logs)
            if self.stop_evaluating:
                break

    # BAYESFLOW: average the metrics across all iterations
    logs = self._mean_logs(logs, total_steps)

    logs = self._get_metrics_result_or_logs(logs)
    callbacks.on_test_end(logs)

    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)

Check warning on line 108 in bayesflow/approximators/backend_approximators/tensorflow_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/tensorflow_approximator.py#L106-L108

Added lines #L106 - L108 were not covered by tests

def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
    """Run one validation batch through ``compute_metrics``."""
    validation_data = {**data, "stage": "validation"}
    return self.compute_metrics(**filter_kwargs(validation_data, self.compute_metrics))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,110 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.torch.trainer import TorchEpochIterator
from keras.src import callbacks as callbacks_module

Check warning on line 7 in bayesflow/approximators/backend_approximators/torch_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/torch_approximator.py#L6-L7

Added lines #L6 - L7 were not covered by tests


class TorchApproximator(keras.Model):
def _aggregate_logs(self, logs, step_logs):
    """Accumulate per-step logs into a running element-wise sum.

    On the first step (``logs`` empty/falsy) the step logs are adopted
    as-is; afterwards matching entries of both structures are added.
    """
    if logs:
        return keras.tree.map_structure(keras.ops.add, logs, step_logs)
    return step_logs

Check warning on line 15 in bayesflow/approximators/backend_approximators/torch_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/torch_approximator.py#L15

Added line #L15 was not covered by tests

def _mean_logs(self, logs, total_steps):
    """Turn summed per-step logs into per-step averages.

    Returns ``logs`` unchanged when no steps were run, so an empty
    evaluation never divides by zero.
    """
    if total_steps == 0:
        return logs
    return keras.tree.map_structure(lambda value: value / total_steps, logs)

Check warning on line 24 in bayesflow/approximators/backend_approximators/torch_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/torch_approximator.py#L24

Added line #L24 was not covered by tests

# noinspection PyMethodOverriding
def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
    """Abstract hook: each concrete architecture supplies its own metrics."""
    raise NotImplementedError

def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Evaluate the model, averaging metrics over all evaluation steps.

    Adapted from the Keras torch trainer's ``evaluate`` (see the
    ``TorchEpochIterator`` import). The BayesFlow-specific change is that
    per-step logs are summed via ``_aggregate_logs`` and divided by the
    step count via ``_mean_logs``, instead of keeping only the logs of
    the last executed step.

    NOTE(review): unlike the JAX/TF siblings, this variant never calls
    ``self._assert_compile_called("evaluate")`` — confirm against the
    upstream Keras torch trainer whether that is intentional.

    Parameters follow ``keras.Model.evaluate``. Returns a dict of metric
    values when ``return_dict=True``, otherwise the flattened metric
    values in declaration order.
    """
    # TODO: respect compiled trainable state
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")

    if use_cached_eval_dataset:
        # Reuse the iterator built by `fit()` for its validation pass.
        epoch_iterator = self._eval_epoch_iterator
    else:
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = TorchEpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

    self._symbolic_build(iterator=epoch_iterator)
    epoch_iterator.reset()

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    # Switch the torch Module back to testing mode.
    self.eval()

    self.make_test_function()
    self.stop_evaluating = False
    callbacks.on_test_begin()
    logs = {}
    total_steps = 0
    self.reset_metrics()
    for step, data in epoch_iterator:
        total_steps += 1

        callbacks.on_test_batch_begin(step)

        # BAYESFLOW: save into step_logs instead of overwriting logs
        step_logs = self.test_function(data)

        # BAYESFLOW: aggregate the metrics across all iterations
        logs = self._aggregate_logs(logs, step_logs)

        # NOTE: callbacks receive the running sum, not the per-step logs.
        callbacks.on_test_batch_end(step, logs)
        if self.stop_evaluating:
            break

    # BAYESFLOW: average the metrics across all iterations
    logs = self._mean_logs(logs, total_steps)

    logs = self._get_metrics_result_or_logs(logs)
    callbacks.on_test_end(logs)

    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)

Check warning on line 108 in bayesflow/approximators/backend_approximators/torch_approximator.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/backend_approximators/torch_approximator.py#L106-L108

Added lines #L106 - L108 were not covered by tests

def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
    """Run one validation batch through ``compute_metrics``."""
    validation_data = {**data, "stage": "validation"}
    return self.compute_metrics(**filter_kwargs(validation_data, self.compute_metrics))
Expand Down
Loading