diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index c3a438761..73af0c47c 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -1,4 +1,4 @@ -from collections.abc import MutableSequence, Sequence, Mapping +from collections.abc import Callable, MutableSequence, Sequence, Mapping import numpy as np @@ -24,6 +24,7 @@ NumpyTransform, OneHot, Rename, + SerializableCustomTransform, Sqrt, Standardize, ToArray, @@ -274,6 +275,88 @@ def apply( self.transforms.append(transform) return self + def apply_serializable( + self, + include: str | Sequence[str] = None, + *, + forward: Callable[[np.ndarray, ...], np.ndarray], + inverse: Callable[[np.ndarray, ...], np.ndarray], + predicate: Predicate = None, + exclude: str | Sequence[str] = None, + **kwargs, + ): + """Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter. + + Parameters + ---------- + forward : function, no lambda + Registered serializable function to transform the data in the forward pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + inverse : function, no lambda + Registered serializable function to transform the data in the inverse pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + predicate : Predicate, optional + Function that indicates which variables should be transformed. + include : str or Sequence of str, optional + Names of variables to include in the transform. + exclude : str or Sequence of str, optional + Names of variables to exclude from the transform. + **kwargs : dict + Additional keyword arguments passed to the transform. + + Raises + ------ + ValueError + When the provided functions are not registered serializable functions. + + Notes + ----- + Important: The forward and inverse functions have to be registered with Keras. + To do so, use the `@keras.saving.register_keras_serializable` decorator. + They must also be registered (and identical) when loading the adapter + at a later point in time. + + Examples + -------- + + The example below shows how to use the + `keras.saving.register_keras_serializable` decorator to + register functions with Keras. Note that for this simple + example, one usually would use the simpler :py:meth:`apply` + method. + + >>> import keras + >>> + >>> @keras.saving.register_keras_serializable("custom") + >>> def forward_fn(x): + >>> return x**2 + >>> + >>> @keras.saving.register_keras_serializable("custom") + >>> def inverse_fn(x): + >>> return x**0.5 + >>> + >>> adapter = bf.Adapter().apply_serializable( + >>> "x", + >>> forward=forward_fn, + >>> inverse=inverse_fn, + >>> ) + """ + transform = FilterTransform( + transform_constructor=SerializableCustomTransform, + predicate=predicate, + include=include, + exclude=exclude, + forward=forward, + inverse=inverse, + **kwargs, + ) + self.transforms.append(transform) + return self + def as_set(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.AsSet` transform to the adapter. diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index b3e95a494..81e9f665f 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -15,6 +15,7 @@ from .one_hot import OneHot from .rename import Rename from .scale import Scale +from .serializable_custom_transform import SerializableCustomTransform from .shift import Shift from .sqrt import Sqrt from .standardize import Standardize diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py new file mode 100644 index 000000000..75d588afd --- /dev/null +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -0,0 +1,183 @@ +from collections.abc import Callable +import numpy as np +from keras.saving import ( + deserialize_keras_object as deserialize, + register_keras_serializable as serializable, + serialize_keras_object as serialize, + get_registered_name, + get_registered_object, +) +from .elementwise_transform import ElementwiseTransform +from ...utils import filter_kwargs +import inspect + + +@serializable(package="bayesflow.adapters") +class SerializableCustomTransform(ElementwiseTransform): + """ + Transforms a parameter using a pair of registered serializable forward and inverse functions. + + Parameters + ---------- + forward : function, no lambda + Registered serializable function to transform the data in the forward pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + inverse : function, no lambda + Function to transform the data in the inverse pass. + For the adapter to be serializable, this function has to be serializable + as well (see Notes). Therefore, only proper functions and no lambda + functions can be used here. + + Raises + ------ + ValueError + When the provided functions are not registered serializable functions. + + Notes + ----- + Important: The forward and inverse functions have to be registered with Keras. + To do so, use the `@keras.saving.register_keras_serializable` decorator. + They must also be registered (and identical) when loading the adapter + at a later point in time. + + """ + + def __init__( + self, + *, + forward: Callable[[np.ndarray, ...], np.ndarray], + inverse: Callable[[np.ndarray, ...], np.ndarray], + ): + super().__init__() + + self._check_serializable(forward, label="forward") + self._check_serializable(inverse, label="inverse") + self._forward = forward + self._inverse = inverse + + @classmethod + def _check_serializable(cls, function, label=""): + GENERAL_EXAMPLE_CODE = ( + "The example code below shows the structure of a correctly decorated function:\n\n" + "```\n" + "import keras\n\n" + "@keras.saving.register_keras_serializable('custom')\n" + f"def my_{label}(...):\n" + " [your code goes here...]\n" + "```\n" + ) + if function is None: + raise TypeError( + f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" + ) + registered_name = get_registered_name(function) + # check if function is a lambda function + if registered_name == "": + raise ValueError( + f"The provided function for '{label}' is a lambda function, " + "which cannot be serialized. " + "Please provide a registered serializable function by using the " + "@keras.saving.register_keras_serializable decorator." + f"\n{GENERAL_EXAMPLE_CODE}" + ) + if inspect.ismethod(function): + raise ValueError( + f"The provided value for '{label}' is a method, not a function. " + "Methods cannot be serialized separately from their classes. " + "Please provide a registered serializable function instead by " + "moving the functionality to a function (i.e., outside of the class) and " + "using the @keras.saving.register_keras_serializable decorator." + f"\n{GENERAL_EXAMPLE_CODE}" + ) + registered_object_for_name = get_registered_object(registered_name) + if registered_object_for_name is None: + try: + source_max_lines = 5 + function_source_code = inspect.getsource(function).split("\n") + if len(function_source_code) > source_max_lines: + function_source_code = function_source_code[:source_max_lines] + [" [...]"] + + example_code = "For your provided function, this would look like this:\n\n" + example_code += "\n".join( + ["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"] + + function_source_code + + ["```"] + ) + except OSError: + example_code = GENERAL_EXAMPLE_CODE + raise ValueError( + f"The provided function for '{label}' is not registered with Keras.\n" + "Please register the function using the " + "@keras.saving.register_keras_serializable decorator.\n" + f"{example_code}" + ) + if registered_object_for_name is not function: + raise ValueError( + f"The provided function for '{label}' does not match the function " + f"registered under its name '{registered_name}'. " + f"(registered function: {registered_object_for_name}, provided function: {function}). " + ) + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform": + if get_registered_object(config["forward"]["config"], custom_objects) is None: + provided_function_msg = "" + if config["_forward_source_code"]: + provided_function_msg = ( + f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```" + ) + raise TypeError( + "\n\nPLEASE READ HERE:\n" + "-----------------\n" + "The forward function that was provided as `forward` " + "is not registered with Keras, making deserialization impossible. " + f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " + "function before loading your model." + f"{provided_function_msg}" + ) + if get_registered_object(config["inverse"]["config"], custom_objects) is None: + provided_function_msg = "" + if config["_inverse_source_code"]: + provided_function_msg = ( + f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```" + ) + raise TypeError( + "\n\nPLEASE READ HERE:\n" + "-----------------\n" + "The inverse function that was provided as `inverse` " + "is not registered with Keras, making deserialization impossible. " + f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " + "function before loading your model." + f"{provided_function_msg}" + ) + forward = deserialize(config["forward"], custom_objects) + inverse = deserialize(config["inverse"], custom_objects) + return cls( + forward=forward, + inverse=inverse, + ) + + def get_config(self) -> dict: + forward_source_code = inverse_source_code = None + try: + forward_source_code = inspect.getsource(self._forward) + inverse_source_code = inspect.getsource(self._inverse) + except OSError: + pass + return { + "forward": serialize(self._forward), + "inverse": serialize(self._inverse), + "_forward_source_code": forward_source_code, + "_inverse_source_code": inverse_source_code, + } + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + # filter kwargs so that other transform args like batch_size, strict, ... are not passed through + kwargs = filter_kwargs(kwargs, self._forward) + return self._forward(data, **kwargs) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + kwargs = filter_kwargs(kwargs, self._inverse) + return self._inverse(data, **kwargs) diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py index a4ad3c8be..b2dd911b7 100644 --- a/bayesflow/experimental/free_form_flow/free_form_flow.py +++ b/bayesflow/experimental/free_form_flow/free_form_flow.py @@ -12,6 +12,7 @@ vjp, serialize_value_or_type, deserialize_value_or_type, + weighted_mean, ) from bayesflow.networks import InferenceNetwork @@ -240,6 +241,6 @@ def decode(z): reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1) losses = maximum_likelihood_loss + self.beta * reconstruction_loss - loss = self.aggregate(losses, sample_weight) + loss = weighted_mean(losses, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 18fefdc91..c983a5cde 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -7,7 +7,13 @@ import numpy as np from bayesflow.types import Tensor -from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum +from bayesflow.utils import ( + find_network, + keras_kwargs, + serialize_value_or_type, + deserialize_value_or_type, + weighted_mean, +) from ..inference_network import InferenceNetwork @@ -331,6 +337,6 @@ def compute_metrics( # Pseudo-huber loss, see [2], Section 3.3 loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber) - loss = weighted_sum(loss, sample_weight) + loss = weighted_mean(loss, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index d50432ea6..403d92511 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -7,7 +7,7 @@ keras_kwargs, serialize_value_or_type, deserialize_value_or_type, - weighted_sum, + weighted_mean, ) from .actnorm import ActNorm @@ -167,6 +167,6 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage) z, log_density = self(x, conditions=conditions, inverse=False, density=True) - loss = weighted_sum(-log_density, sample_weight) + loss = weighted_mean(-log_density, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index f8a39629b..85c338d82 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -13,7 +13,7 @@ optimal_transport, serialize_value_or_type, deserialize_value_or_type, - weighted_sum, + weighted_mean, ) from ..inference_network import InferenceNetwork @@ -260,6 +260,6 @@ def compute_metrics( predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training") loss = self.loss_fn(target_velocity, predicted_velocity) - loss = weighted_sum(loss, sample_weight) + loss = weighted_mean(loss, sample_weight) return base_metrics | {"loss": loss} diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py index 9659d3726..eb2795927 100644 --- a/bayesflow/scores/normed_difference_score.py +++ b/bayesflow/scores/normed_difference_score.py @@ -2,6 +2,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor +from bayesflow.utils import weighted_mean from .scoring_rule import ScoringRule @@ -55,7 +56,7 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = """ estimates = estimates["value"] scores = keras.ops.absolute(estimates - targets) ** self.k - score = self.aggregate(scores, weights) + score = weighted_mean(scores, weights) return score def get_config(self): diff --git a/bayesflow/scores/parametric_distribution_score.py b/bayesflow/scores/parametric_distribution_score.py index bdd62d692..3ead3271f 100644 --- a/bayesflow/scores/parametric_distribution_score.py +++ b/bayesflow/scores/parametric_distribution_score.py @@ -1,6 +1,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Tensor +from bayesflow.utils import weighted_mean from .scoring_rule import ScoringRule @@ -29,5 +30,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = :math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))` """ scores = -self.log_prob(x=targets, **estimates) - score = self.aggregate(scores, weights) + score = weighted_mean(scores, weights) return score diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py index 2e3ec54ef..b05a35fc5 100644 --- a/bayesflow/scores/quantile_score.py +++ b/bayesflow/scores/quantile_score.py @@ -4,7 +4,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor -from bayesflow.utils import logging +from bayesflow.utils import logging, weighted_mean from bayesflow.links import OrderedQuantiles from .scoring_rule import ScoringRule @@ -39,7 +39,7 @@ def get_config(self): base_config = super().get_config() return base_config | self.config - def get_head_shapes_from_target_shape(self, target_shape: Shape): + def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, tuple]: # keras.saving.load_model sometimes passes target_shape as a list, so we force a conversion target_shape = tuple(target_shape) return dict(value=(len(self.q),) + target_shape[1:]) @@ -49,5 +49,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor = pointwise_differance = estimates - targets[:, None, :] scores = pointwise_differance * (keras.ops.cast(pointwise_differance > 0, float) - self._q[None, :, None]) scores = keras.ops.mean(scores, axis=1) - score = self.aggregate(scores, weights) + score = weighted_mean(scores, weights) return score diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index 87ad216e9..a1a3f5717 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -204,30 +204,3 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) """ raise NotImplementedError - - def aggregate(self, scores: Tensor, weights: Tensor = None) -> Tensor: - """ - Computes the mean of **scores**, optionally applying **weights**. - - This function computes the mean value of the given scores. When weights are provided, - it first multiplies the scores by the weights and then computes the mean of the result. - If no weights are provided, it computes the mean of the scores. - - Parameters - ---------- - scores : Tensor - A tensor containing the scores to be aggregated. - weights : Tensor, optional (default - None) - A tensor of weights corresponding to each score. Must be the same shape as `scores`. - If not provided, the function returns the mean of `scores`. - - Returns - ------- - Tensor - The aggregated score computed as a weighted mean if **weights** is provided, - or as the simple mean of **scores** otherwise. - """ - - if weights is not None: - return keras.ops.mean(scores * weights) - return keras.ops.mean(scores) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 5bdaea274..1e1092e0c 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -71,7 +71,7 @@ tree_concatenate, tree_stack, fill_triangular_matrix, - weighted_sum, + weighted_mean, ) from .classification import calibration_curve, confusion_matrix from .validators import check_lengths_same diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a932424f5..5e3b407ec 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial import keras +import numpy as np from typing import Literal from bayesflow.types import Tensor @@ -233,21 +234,62 @@ def body(_state, _time, _step_size, _step): return state +def integrate_scheduled( + fn: Callable, + state: dict[str, ArrayLike], + steps: Tensor | np.ndarray, + method: str = "rk45", + **kwargs, +) -> dict[str, ArrayLike]: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - steps[_loop_var] + + _loop_state, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + return state + + def integrate( fn: Callable, state: dict[str, ArrayLike], - start_time: ArrayLike, - stop_time: ArrayLike, + start_time: ArrayLike | None = None, + stop_time: ArrayLike | None = None, min_steps: int = 10, max_steps: int = 10_000, - steps: int | Literal["adaptive"] = 100, + steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, method: str = "euler", **kwargs, ) -> dict[str, ArrayLike]: - match steps: - case "adaptive" | "dynamic": - return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) - case int(): - return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) - case _: - raise RuntimeError("Type or value of `steps` not understood.") + if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." + ) + return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) + elif isinstance(steps, int): + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." + ) + return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) + elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps): + return integrate_scheduled(fn, state, steps, method, **kwargs) + else: + raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") diff --git a/bayesflow/utils/tensor_utils.py b/bayesflow/utils/tensor_utils.py index cd6f6d4ca..4d89249b7 100644 --- a/bayesflow/utils/tensor_utils.py +++ b/bayesflow/utils/tensor_utils.py @@ -140,7 +140,7 @@ def pad(x: Tensor, value: float | Tensor, n: int, axis: int, side: str = "both") raise TypeError(f"Invalid side type {type(side)!r}. Must be str.") -def weighted_sum(elements: Tensor, weights: Tensor = None) -> Tensor: +def weighted_mean(elements: Tensor, weights: Tensor = None) -> Tensor: """ Compute the (optionally) weighted mean of the input tensor. diff --git a/bayesflow/wrappers/mamba/mamba_block.py b/bayesflow/wrappers/mamba/mamba_block.py index 17617e33d..f629807a0 100644 --- a/bayesflow/wrappers/mamba/mamba_block.py +++ b/bayesflow/wrappers/mamba/mamba_block.py @@ -55,8 +55,8 @@ def __init__( super().__init__(**keras_kwargs(kwargs)) - # if keras.backend.backend() != "torch": - # raise RuntimeError("Mamba is only available using torch backend.") + if keras.backend.backend() != "torch": + raise RuntimeError("Mamba is only available using torch backend.") try: from mamba_ssm import Mamba diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index abd5797bd..d379d7cc4 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -5,6 +5,11 @@ @pytest.fixture() def adapter(): from bayesflow.adapters import Adapter + import keras + + @keras.saving.register_keras_serializable("custom") + def serializable_fn(x): + return x d = ( Adapter() @@ -20,6 +25,7 @@ def adapter(): .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") .apply(include="p2", forward="log1p") + .apply_serializable(include="x", forward=serializable_fn, inverse=serializable_fn) .scale("x", by=[-1, 2]) .shift("x", by=2) .standardize(exclude=["t1", "t2", "o1"]) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 3dea0baf4..0d17c419f 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -3,6 +3,7 @@ serialize_keras_object as serialize, ) import numpy as np +import pytest def test_cycle_consistency(adapter, random_data): @@ -110,3 +111,69 @@ def test_simple_transforms(random_data): assert np.allclose(inverse["t1"], random_data["t1"]) assert np.allclose(inverse["p1"], random_data["p1"]) + + +def test_custom_transform(): + # test that transform raises errors in all relevant cases + import keras + from bayesflow.adapters.transforms import SerializableCustomTransform + from copy import deepcopy + + class A: + @classmethod + def fn(cls, x): + return x + + def not_registered_fn(x): + return x + + @keras.saving.register_keras_serializable("custom") + def registered_fn(x): + return x + + @keras.saving.register_keras_serializable("custom") + def registered_but_changed(x): + return x + + def registered_but_changed(x): # noqa: F811 + return 2 * x + + # method instead of function provided + with pytest.raises(ValueError): + SerializableCustomTransform(forward=A.fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=A.fn) + + # lambda function provided + with pytest.raises(ValueError): + SerializableCustomTransform(forward=lambda x: x, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=lambda x: x) + + # unregistered function provided + with pytest.raises(ValueError): + SerializableCustomTransform(forward=not_registered_fn, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=not_registered_fn) + + # function does not match registered function + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_but_changed, inverse=registered_fn) + with pytest.raises(ValueError): + SerializableCustomTransform(forward=registered_fn, inverse=registered_but_changed) + + transform = SerializableCustomTransform(forward=registered_fn, inverse=registered_fn) + serialized_transform = keras.saving.serialize_keras_object(transform) + keras.saving.deserialize_keras_object(serialized_transform) + + # modify name of the forward function so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["forward"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) + + # modify name of the inverse transform so that it cannot be found + corrupt_serialized_transform = deepcopy(serialized_transform) + corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent" + with pytest.raises(TypeError): + keras.saving.deserialize_keras_object(corrupt_serialized_transform) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/conftest.py b/tests/test_utils/conftest.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py new file mode 100644 index 000000000..8147d0589 --- /dev/null +++ b/tests/test_utils/test_integrate.py @@ -0,0 +1,11 @@ +def test_scheduled_integration(): + import keras + from bayesflow.utils import integrate + + def fn(t, x): + return {"x": t**2} + + steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0]) + approximate_result = 0.0 + 0.5**2 * 0.5 + result = integrate(fn, {"x": 0.0}, steps=steps)["x"] + assert result == approximate_result