85 changes: 84 additions & 1 deletion bayesflow/adapters/adapter.py
@@ -1,4 +1,4 @@
from collections.abc import MutableSequence, Sequence, Mapping
from collections.abc import Callable, MutableSequence, Sequence, Mapping

import numpy as np

@@ -24,6 +24,7 @@
NumpyTransform,
OneHot,
Rename,
SerializableCustomTransform,
Sqrt,
Standardize,
ToArray,
@@ -274,6 +275,88 @@ def apply(
self.transforms.append(transform)
return self

def apply_serializable(
self,
include: str | Sequence[str] = None,
*,
forward: Callable[[np.ndarray, ...], np.ndarray],
inverse: Callable[[np.ndarray, ...], np.ndarray],
predicate: Predicate = None,
exclude: str | Sequence[str] = None,
**kwargs,
):
"""Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.

Parameters
----------
forward : function, no lambda
Registered serializable function to transform the data in the forward pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions (not lambda
functions) can be used here.
inverse : function, no lambda
Registered serializable function to transform the data in the inverse pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions (not lambda
functions) can be used here.
predicate : Predicate, optional
Function that indicates which variables should be transformed.
include : str or Sequence of str, optional
Names of variables to include in the transform.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
**kwargs : dict
Additional keyword arguments passed to the transform.

Raises
------
ValueError
When the provided functions are not registered serializable functions.

Notes
-----
Important: The forward and inverse functions have to be registered with Keras.
To do so, use the `@keras.saving.register_keras_serializable` decorator.
They must also be registered (and identical) when loading the adapter
at a later point in time.

Examples
--------

The example below shows how to use the
`keras.saving.register_keras_serializable` decorator to
register functions with Keras. Note that for this simple
example, one would usually use the simpler :py:meth:`apply`
method.

>>> import keras
>>> import bayesflow as bf
>>>
>>> @keras.saving.register_keras_serializable("custom")
... def forward_fn(x):
...     return x**2
>>>
>>> @keras.saving.register_keras_serializable("custom")
... def inverse_fn(x):
...     return x**0.5
>>>
>>> adapter = bf.Adapter().apply_serializable(
...     "x",
...     forward=forward_fn,
...     inverse=inverse_fn,
... )
"""
transform = FilterTransform(
transform_constructor=SerializableCustomTransform,
predicate=predicate,
include=include,
exclude=exclude,
forward=forward,
inverse=inverse,
**kwargs,
)
self.transforms.append(transform)
return self

def as_set(self, keys: str | Sequence[str]):
"""Append an :py:class:`~transforms.AsSet` transform to the adapter.

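The point of the `apply_serializable` method added above is that the whole adapter survives a save/load round trip. Below is a minimal sketch of such a round trip, assuming the `Adapter` class itself is registered with Keras (as in current BayesFlow releases) and that it exposes a `forward(...)` call; the data values are illustrative only.

```python
import keras
import numpy as np
import bayesflow as bf


@keras.saving.register_keras_serializable("custom")
def forward_fn(x):
    return x**2


@keras.saving.register_keras_serializable("custom")
def inverse_fn(x):
    return x**0.5


adapter = bf.Adapter().apply_serializable("x", forward=forward_fn, inverse=inverse_fn)

# Round trip through the Keras serialization machinery; this only succeeds
# because forward_fn and inverse_fn are registered under stable names.
config = keras.saving.serialize_keras_object(adapter)
restored = keras.saving.deserialize_keras_object(config)

# Applying the restored adapter reproduces forward_fn (here: squaring "x").
data = {"x": np.array([1.0, 2.0, 3.0])}
print(restored.forward(data)["x"])  # [1. 4. 9.]
```

The same two functions must be importable and registered again in the process that later deserializes the adapter; otherwise loading fails with the errors implemented in `SerializableCustomTransform.from_config` below.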
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
@@ -15,6 +15,7 @@
from .one_hot import OneHot
from .rename import Rename
from .scale import Scale
from .serializable_custom_transform import SerializableCustomTransform
from .shift import Shift
from .sqrt import Sqrt
from .standardize import Standardize
183 changes: 183 additions & 0 deletions bayesflow/adapters/transforms/serializable_custom_transform.py
@@ -0,0 +1,183 @@
from collections.abc import Callable
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
get_registered_name,
get_registered_object,
)
from .elementwise_transform import ElementwiseTransform
from ...utils import filter_kwargs
import inspect


@serializable(package="bayesflow.adapters")
class SerializableCustomTransform(ElementwiseTransform):
"""
Transforms a parameter using a pair of registered serializable forward and inverse functions.

Parameters
----------
forward : function, no lambda
Registered serializable function to transform the data in the forward pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions (not lambda
functions) can be used here.
inverse : function, no lambda
Registered serializable function to transform the data in the inverse pass.
For the adapter to be serializable, this function has to be serializable
as well (see Notes). Therefore, only proper functions (not lambda
functions) can be used here.

Raises
------
ValueError
When the provided functions are not registered serializable functions.

Notes
-----
Important: The forward and inverse functions have to be registered with Keras.
To do so, use the `@keras.saving.register_keras_serializable` decorator.
They must also be registered (and identical) when loading the adapter
at a later point in time.

"""

def __init__(
self,
*,
forward: Callable[[np.ndarray, ...], np.ndarray],
inverse: Callable[[np.ndarray, ...], np.ndarray],
):
super().__init__()

self._check_serializable(forward, label="forward")
self._check_serializable(inverse, label="inverse")
self._forward = forward
self._inverse = inverse

@classmethod
def _check_serializable(cls, function, label=""):
GENERAL_EXAMPLE_CODE = (
"The example code below shows the structure of a correctly decorated function:\n\n"
"```\n"
"import keras\n\n"
"@keras.saving.register_keras_serializable('custom')\n"
f"def my_{label}(...):\n"
" [your code goes here...]\n"
"```\n"
)
if function is None:
raise TypeError(
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
)
registered_name = get_registered_name(function)
# check if function is a lambda function
if registered_name == "<lambda>":
raise ValueError(
f"The provided function for '{label}' is a lambda function, "
"which cannot be serialized. "
"Please provide a registered serializable function by using the "
"@keras.saving.register_keras_serializable decorator."
f"\n{GENERAL_EXAMPLE_CODE}"
)
if inspect.ismethod(function):
raise ValueError(
f"The provided value for '{label}' is a method, not a function. "
"Methods cannot be serialized separately from their classes. "
"Please provide a registered serializable function instead by "
"moving the functionality to a function (i.e., outside of the class) and "
"using the @keras.saving.register_keras_serializable decorator."
f"\n{GENERAL_EXAMPLE_CODE}"
)
registered_object_for_name = get_registered_object(registered_name)
if registered_object_for_name is None:
try:
source_max_lines = 5
function_source_code = inspect.getsource(function).split("\n")
if len(function_source_code) > source_max_lines:
function_source_code = function_source_code[:source_max_lines] + [" [...]"]

example_code = "For your provided function, this would look like this:\n\n"
example_code += "\n".join(
["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
+ function_source_code
+ ["```"]
)
except OSError:
example_code = GENERAL_EXAMPLE_CODE
raise ValueError(
f"The provided function for '{label}' is not registered with Keras.\n"
"Please register the function using the "
"@keras.saving.register_keras_serializable decorator.\n"
f"{example_code}"
)
if registered_object_for_name is not function:
raise ValueError(
f"The provided function for '{label}' does not match the function "
f"registered under its name '{registered_name}'. "
f"(registered function: {registered_object_for_name}, provided function: {function}). "
)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
if get_registered_object(config["forward"]["config"], custom_objects) is None:
provided_function_msg = ""
if config["_forward_source_code"]:
provided_function_msg = (
f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```"
)
raise TypeError(
"\n\nPLEASE READ HERE:\n"
"-----------------\n"
"The forward function that was provided as `forward` "
"is not registered with Keras, making deserialization impossible. "
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
"function before loading your model."
f"{provided_function_msg}"
)
if get_registered_object(config["inverse"]["config"], custom_objects) is None:
provided_function_msg = ""
if config["_inverse_source_code"]:
provided_function_msg = (
f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```"
)
raise TypeError(
"\n\nPLEASE READ HERE:\n"
"-----------------\n"
"The inverse function that was provided as `inverse` "
"is not registered with Keras, making deserialization impossible. "
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
"function before loading your model."
f"{provided_function_msg}"
)
forward = deserialize(config["forward"], custom_objects)
inverse = deserialize(config["inverse"], custom_objects)
return cls(
forward=forward,
inverse=inverse,
)

def get_config(self) -> dict:
forward_source_code = inverse_source_code = None
try:
forward_source_code = inspect.getsource(self._forward)
inverse_source_code = inspect.getsource(self._inverse)
except OSError:
pass
return {
"forward": serialize(self._forward),
"inverse": serialize(self._inverse),
"_forward_source_code": forward_source_code,
"_inverse_source_code": inverse_source_code,
}

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
# filter kwargs so that other transform args like batch_size, strict, ... are not passed through
kwargs = filter_kwargs(kwargs, self._forward)
return self._forward(data, **kwargs)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
kwargs = filter_kwargs(kwargs, self._inverse)
return self._inverse(data, **kwargs)
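The validation in `_check_serializable` is easiest to see in action. A short sketch of the accepted and rejected inputs, based on the code above; `squared`, `square_root`, and `unregistered` are illustrative names, and the printed messages are the `ValueError` texts constructed in the constructor checks.

```python
import keras
import numpy as np
from bayesflow.adapters.transforms import SerializableCustomTransform


@keras.saving.register_keras_serializable("custom")
def squared(x):
    return x**2


@keras.saving.register_keras_serializable("custom")
def square_root(x):
    return x**0.5


def unregistered(x):  # no decorator, so Keras cannot look it up by name
    return x


# OK: both functions are registered under stable names.
transform = SerializableCustomTransform(forward=squared, inverse=square_root)
assert np.allclose(transform.forward(np.array([2.0])), [4.0])

# ValueError: lambdas cannot be serialized.
try:
    SerializableCustomTransform(forward=lambda x: x, inverse=square_root)
except ValueError as err:
    print(err)

# ValueError: plain functions must be registered with Keras first.
try:
    SerializableCustomTransform(forward=unregistered, inverse=square_root)
except ValueError as err:
    print(err)
```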
3 changes: 2 additions & 1 deletion bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -12,6 +12,7 @@
vjp,
serialize_value_or_type,
deserialize_value_or_type,
weighted_mean,
)

from bayesflow.networks import InferenceNetwork
@@ -240,6 +241,6 @@ def decode(z):
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)

losses = maximum_likelihood_loss + self.beta * reconstruction_loss
loss = self.aggregate(losses, sample_weight)
loss = weighted_mean(losses, sample_weight)

return base_metrics | {"loss": loss}
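In this file and the network and score files below, loss aggregation switches from `weighted_sum` (or the scores' `self.aggregate`) to the shared `bayesflow.utils.weighted_mean` helper. For intuition, here is a hypothetical sketch of what such a helper computes; it is not the actual BayesFlow implementation, which may handle broadcasting and dtypes differently.

```python
from keras import ops


def weighted_mean_sketch(values, sample_weight=None):
    # Hypothetical reference: reduce per-sample losses to a weighted average.
    # With sample_weight=None this degenerates to a plain mean over all elements.
    if sample_weight is None:
        return ops.mean(values)
    return ops.sum(values * sample_weight) / ops.sum(sample_weight)
```

With this reduction, `loss = weighted_mean(losses, sample_weight)` is a weighted average, so its scale stays independent of the batch size, unlike a weighted sum.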
10 changes: 8 additions & 2 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -7,7 +7,13 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum
from bayesflow.utils import (
find_network,
keras_kwargs,
serialize_value_or_type,
deserialize_value_or_type,
weighted_mean,
)


from ..inference_network import InferenceNetwork
@@ -331,6 +337,6 @@ def compute_metrics(

# Pseudo-huber loss, see [2], Section 3.3
loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
loss = weighted_sum(loss, sample_weight)
loss = weighted_mean(loss, sample_weight)

return base_metrics | {"loss": loss}
4 changes: 2 additions & 2 deletions bayesflow/networks/coupling_flow/coupling_flow.py
@@ -7,7 +7,7 @@
keras_kwargs,
serialize_value_or_type,
deserialize_value_or_type,
weighted_sum,
weighted_mean,
)

from .actnorm import ActNorm
@@ -167,6 +167,6 @@ def compute_metrics(
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)

z, log_density = self(x, conditions=conditions, inverse=False, density=True)
loss = weighted_sum(-log_density, sample_weight)
loss = weighted_mean(-log_density, sample_weight)

return base_metrics | {"loss": loss}
4 changes: 2 additions & 2 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -13,7 +13,7 @@
optimal_transport,
serialize_value_or_type,
deserialize_value_or_type,
weighted_sum,
weighted_mean,
)
from ..inference_network import InferenceNetwork

@@ -260,6 +260,6 @@ def compute_metrics(
predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")

loss = self.loss_fn(target_velocity, predicted_velocity)
loss = weighted_sum(loss, sample_weight)
loss = weighted_mean(loss, sample_weight)

return base_metrics | {"loss": loss}
3 changes: 2 additions & 1 deletion bayesflow/scores/normed_difference_score.py
@@ -2,6 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import weighted_mean

from .scoring_rule import ScoringRule

@@ -55,7 +56,7 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
"""
estimates = estimates["value"]
scores = keras.ops.absolute(estimates - targets) ** self.k
score = self.aggregate(scores, weights)
score = weighted_mean(scores, weights)
return score

def get_config(self):
3 changes: 2 additions & 1 deletion bayesflow/scores/parametric_distribution_score.py
@@ -1,6 +1,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import weighted_mean

from .scoring_rule import ScoringRule

@@ -29,5 +30,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
:math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))`
"""
scores = -self.log_prob(x=targets, **estimates)
score = self.aggregate(scores, weights)
score = weighted_mean(scores, weights)
return score