2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -43,7 +43,7 @@ jobs:

- name: Install Dependencies
run: |
pip install -U pip setuptools wheel
python -m pip install -U pip setuptools wheel
pip install .[test]

- name: Install JAX
3 changes: 2 additions & 1 deletion bayesflow/utils/__init__.py
@@ -6,6 +6,7 @@
keras_utils,
logging,
numpy_utils,
serialization,
)

from .callbacks import detailed_loss_callback
@@ -104,4 +105,4 @@

from ._docs import _add_imports_to_all

_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils"])
_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"])
1 change: 1 addition & 0 deletions bayesflow/utils/decorators.py
@@ -17,6 +17,7 @@ def allow_args(fn: Decorator) -> Decorator:
def wrapper(f: Fn) -> Fn: ...
@overload
def wrapper(*fargs: any, **fkwargs: any) -> Fn: ...
@wraps(fn)
def wrapper(*fargs: any, **fkwargs: any) -> Fn:
if len(fargs) == 1 and not fkwargs and callable(fargs[0]):
# called without arguments
89 changes: 79 additions & 10 deletions bayesflow/utils/serialization.py
@@ -92,36 +92,85 @@
return updated_config


def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
"""Deserialize an object serialized with :py:func:`serialize`.

Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of
classes.

Parameters
----------
config : dict
Python dict describing the object.
custom_objects : dict, optional
Python dict containing a mapping between custom object names and the corresponding
classes or functions. Forwarded to `keras.saving.deserialize_keras_object`.
safe_mode : bool, optional
Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False,
loading an object has the potential to trigger arbitrary code execution. This argument
is only applicable to the Keras v3 model format. Defaults to True.
Forwarded to `keras.saving.deserialize_keras_object`.

Returns
-------
obj :
The object described by the config dictionary.

Raises
------
ValueError
If a type in the config cannot be deserialized.

See Also
--------
serialize
"""
with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize:
if isinstance(obj, str) and obj.startswith(_type_prefix):
if isinstance(config, str) and config.startswith(_type_prefix):
# we marked this as a type during serialization
obj = obj[len(_type_prefix) :]
config = config[len(_type_prefix) :]
tp = keras.saving.get_registered_object(
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
obj,
config,
custom_objects=custom_objects,
module_objects=np.__dict__ | builtins.__dict__,
)
if tp is None:
raise ValueError(
f"Could not deserialize type {obj!r}. Make sure it is registered with "
f"Could not deserialize type {config!r}. Make sure it is registered with "
f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
)
return tp
if inspect.isclass(obj):
if inspect.isclass(config):
# add this base case since keras does not cover it
return obj
return config

Codecov / codecov/patch check warning on line 146 in bayesflow/utils/serialization.py: Added line #L146 was not covered by tests.

obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
obj = original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)

return obj


@allow_args
def serializable(cls, package=None, name=None):
def serializable(cls, package: str | None = None, name: str | None = None):
"""Register class as Keras serialize.

Wrapper function around `keras.saving.register_keras_serializable` to automatically
set the `package` and `name` arguments.

Parameters
----------
cls : type
The class to register.
package : str, optional
`package` argument forwarded to `keras.saving.register_keras_serializable`.
If None is provided, the package is automatically inferred using the __name__
attribute of the module the class resides in.
name : str, optional
`name` argument forwarded to `keras.saving.register_keras_serializable`.
If None is provided, the class's __name__ attribute is used.
"""
if package is None:
frame = sys._getframe(1)
frame = sys._getframe(2)
g = frame.f_globals
package = g.get("__name__", "bayesflow")

@@ -133,6 +182,26 @@


def serialize(obj):
"""Serialize an object using Keras.

Wrapper function around `keras.saving.serialize_keras_object`, which adds the
ability to serialize classes.

Parameters
----------
obj : Keras serializable object or class
The object to serialize.

Returns
-------
config : dict
A python dict that represents the object. The python dict can be deserialized via
:py:func:`deserialize`.

See Also
--------
deserialize
"""
if isinstance(obj, (tuple, list, dict)):
return keras.tree.map_structure(serialize, obj)
elif inspect.isclass(obj):
90 changes: 13 additions & 77 deletions docsrc/source/development/index.md
@@ -1,87 +1,23 @@
# Patterns & Caveats
# Developer Documentation

**Note**: This document is part of BayesFlow's developer documentation, and
**Attention:** You are looking at BayesFlow's developer documentation, which is
aimed at people who want to extend or improve BayesFlow. For user documentation,
please refer to the examples and the public API documentation.
please refer to the {doc}`../examples` and the {doc}`../api/bayesflow`.

## Introduction

From version 2 on, BayesFlow is built on [Keras](https://keras.io/) v3, which
allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch.
By using functionality provided by Keras, and extending it with backend-specific
code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as
well.

As Keras is built upon three different backend, each with different functionality
and design decisions, it has its own quirks and compromises. This documents
outlines some of them, along with the design decisions and programming patterns
we use to counter them.

This document is work in progress, so if you read through the code base and
This section is work in progress, so if you read through the code base and
encounter something that looks odd, but shows up in multiple places, please open
an issue so that we can add it here. Also, if you introduce a new pattern that
others will have to use in the future as well, please document it here, along
with some background information on why it is necessary and how to use it in
practice.

## Privileged `training` argument in the `call()` method cannot be passed via `kwargs`

For layers that have different behavior at training and inference time (e.g.,
dropout or batch normalization layers), a boolean `training` argument can be
exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method).
If we want to pass this manually, we have to do so explicitly and not as part
of a set of keyword arguments via `**kwargs`.

@Lars: Maybe you can add more details on what is going on behind the scenes.

## Serialization

Serialization deals with the problem of storing objects to disk, and loading
them at a later point in time. This is straight-forward for data structures like
numpy arrays, but for classes with custom behavior, like approximators or neural
network layers, it is somewhat more complex.

Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/)
for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/)
for advanced concepts.

The basic idea is: by storing the arguments of the constructor of a class
(i.e., the arguments of the `__init__` function), we can later construct an
object identical to the one we have stored, except for the weights.
As the structure is identical, we can then map the stored weights to the newly
constructed object. The caveat is that all arguments have to be either basic
Python objects (like int, float, string, bool, ...) or themselves serializable.
If they are not, we have to manually specify how to serialize them, and how to
load them later on.

### Registering classes as serializable

TODO

### Serialization of custom types

In BayesFlow, we often encounter situations where we do not want to pass a
specific object (e.g., an MPL of a certain size), but we want to pass its type
(MLP) and the arguments to construct it. With the type and the arguments, we can
then construct multiple instances of the network in different places, for example
as the network inside a coupling block.

Unfortunately, `type` is not Keras serializable, so we have to serialize those
arguments manually. To complicate matters further, we also allow passing a string
instead of a type, which is then used to select the correct type.

To make it more concrete, we look at the `CouplingFlow` class, which takes the
argument `subnet` that provide the type of the subnet. It is either a
string (e.g., `"mlp"`) or a class (e.g., `bayesflow.networks.MLP`). In the first
case, we can just store the value and load it, in the latter case, we first have
to convert the type to a string that we can later convert back into a type.

We provide two helper functions that can deal with both cases:
`bayesflow.utils.serialize_value_or_type(config, name, obj)` and
`bayesflow.utils.deserialize_value_or_type(config, name)`.
In `get_config`, we use the first to store the object, whereas we use the
latter in `from_config` to load it again.
```{toctree}
:maxdepth: 1
:titlesonly:
:numbered:

As we need all arguments to `__init__` in `get_config`, it can make sense to
build a `config` dictionary in `__init__` already, which can then be stored when
`get_config` is called. Take a look at `CouplingFlow` for an example of that.
introduction
pitfalls
stages
serialization
```
12 changes: 12 additions & 0 deletions docsrc/source/development/introduction.md
@@ -0,0 +1,12 @@
# Introduction

From version 2 on, BayesFlow is built on [Keras 3](https://keras.io/), which
allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch.
By using functionality provided by Keras, and extending it with backend-specific
code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as
well.

As Keras is built upon three different backends, each with different functionality
and design decisions, it comes with its own quirks and compromises. The following documents
outline some of them, along with the design decisions and programming patterns
we use to counter them.
13 changes: 13 additions & 0 deletions docsrc/source/development/pitfalls.md
@@ -0,0 +1,13 @@
# Potential Pitfalls

This document covers things we have learned during development that might cause problems or hard-to-find bugs.

## Privileged `training` argument in the `call()` method cannot be passed via `kwargs`

For layers that have different behavior at training and inference time (e.g.,
dropout or batch normalization layers), a boolean `training` argument can be
exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method).
If we want to pass this manually, we have to do so explicitly and not as part
of a set of keyword arguments via `**kwargs`.
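
A hypothetical pair of layers illustrating the difference (not taken from the BayesFlow code base):

```python
import keras


class ExplicitTraining(keras.layers.Layer):
    def call(self, x, training=False):
        # `training` is declared explicitly, so it can be passed and forwarded correctly
        return keras.random.dropout(x, rate=0.5) if training else x


class KwargsTraining(keras.layers.Layer):
    def call(self, x, **kwargs):
        # relying on `training` arriving via **kwargs is fragile: Keras only injects the
        # privileged `training` argument for parameters that are declared explicitly
        return keras.random.dropout(x, rate=0.5) if kwargs.get("training", False) else x
```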

@Lars: Maybe you can add more details on what is going on behind the scenes.
28 changes: 28 additions & 0 deletions docsrc/source/development/serialization.md
@@ -0,0 +1,28 @@
# Serialization: Enable Model Saving & Loading

Serialization deals with the problem of storing objects to disk, and loading them at a later point in time.
This is straightforward for data structures like numpy arrays, but for classes with custom behavior it is somewhat more complex.

Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/) for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/) for advanced concepts.

The basic idea is: by storing the arguments of the constructor of a class (i.e., the arguments of the `__init__` function), we can later construct an object similar to the one we have stored, except for the weights and other stateful content.
As the structure is identical, we can then map the stored weights to the newly constructed object.
The caveat is that all arguments have to be either basic Python objects (like int, float, string, bool, ...) or themselves serializable.
If they are not, we have to manually specify how to serialize them, and how to load them later on.
One important example is that types are not serializable.
As we need to pass types in some places, we have to resort to custom behavior, which is described below.
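
For illustration, a hypothetical layer following this pattern with plain Keras (no BayesFlow utilities yet) could look like this:

```python
import keras


@keras.saving.register_keras_serializable(package="my_package")
class ScaledDense(keras.layers.Layer):
    def __init__(self, units: int, scale: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.scale = scale
        self.dense = keras.layers.Dense(units)

    def call(self, x):
        return self.scale * self.dense(x)

    def get_config(self):
        # store the constructor arguments; here they are all basic Python objects
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config
```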

## Serialization Utilities

BayesFlow's serialization utilities can be found in the {py:mod}`~bayesflow.utils.serialization` module.
We mainly provide three convenience functions:

- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to provide automatic `package` and `name` arguments.
- The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes.
- Its counterpart, {py:func}`~bayesflow.utils.serialization.deserialize`, which adds support for deserializing classes (see the sketch below).
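
For example, a class (rather than an instance) can make the round trip through `serialize` and `deserialize`; the snippet below is a sketch that assumes `bayesflow.networks.MLP` is registered as serializable:

```python
from bayesflow.networks import MLP
from bayesflow.utils.serialization import deserialize, serialize

config = serialize(MLP)        # the class itself, not an instance
restored = deserialize(config)
assert restored is MLP
```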

_Note: The `(de)serialize_value_or_type` functions are made obsolete by the functions given above and will probably be deprecated soon._

## Usage

To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` methods. Please refer to existing classes in the library for usage examples.
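
A minimal sketch of this pattern (the class and its constructor arguments are hypothetical, not an existing BayesFlow class):

```python
import keras

from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable
class MyCouplingNetwork(keras.layers.Layer):
    def __init__(self, subnet=None, depth: int = 2, **kwargs):
        super().__init__(**kwargs)
        # `subnet` may be a class (e.g. a network type), which plain Keras cannot serialize
        self.subnet = subnet
        self.depth = depth

    def get_config(self):
        base_config = super().get_config()
        config = {"subnet": self.subnet, "depth": self.depth}
        return base_config | serialize(config)

    @classmethod
    def from_config(cls, config):
        return cls(**deserialize(config))
```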
8 changes: 8 additions & 0 deletions docsrc/source/development/stages.md
@@ -0,0 +1,8 @@
# Stages

To keep track of the phase in which each piece of functionality is called, we provide a `stage` parameter (a minimal sketch follows the list below).
There are three stages:

- `training`: The stage to train the approximator (and related stateful objects, like the adapter)
- `validation`: Identical setting to `training`, but calls in this stage should _not_ change the approximator
- `inference`: Calls in this stage should not change the approximator. In addition, the input structure might differ from the training phase. For example, for sampling we only provide `summary_conditions` and `inference_conditions`, but not the `inference_variables`, which we want to infer.
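
A hypothetical sketch of how a stage-aware function might branch on this parameter:

```python
def adapt(data: dict, stage: str = "training") -> dict:
    # hypothetical helper illustrating stage-dependent behaviour
    if stage == "training":
        # stateful objects (e.g. the adapter) may be updated here
        pass
    elif stage == "validation":
        # same input structure as training, but nothing stateful is updated
        pass
    elif stage == "inference":
        # `inference_variables` are typically absent; only conditions are provided
        pass
    return data
```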