-
Notifications
You must be signed in to change notification settings - Fork 78
Add class for custom transforms to adapter, similar to LambdaTransform.
#399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vpratz
merged 4 commits into
bayesflow-org:dev
from
vpratz:feat-serializable-custom-transform
Apr 12, 2025
Merged
Changes from 1 commit
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
2ba93d4
Add class for custom transforms to adapter.
vpratz 8fdfde9
custom transform: less verbose naming, fix tests
vpratz 0591391
Merge branch 'dev' into feat-serializable-custom-transform
vpratz 630e9cf
Merge remote-tracking branch 'upstream/dev' into feat-serializable-cu…
vpratz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
184 changes: 184 additions & 0 deletions
184
bayesflow/adapters/transforms/serializable_custom_transform.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| from collections.abc import Callable | ||
| import numpy as np | ||
| from keras.saving import ( | ||
| deserialize_keras_object as deserialize, | ||
| register_keras_serializable as serializable, | ||
| serialize_keras_object as serialize, | ||
| get_registered_name, | ||
| get_registered_object, | ||
| ) | ||
| from .elementwise_transform import ElementwiseTransform | ||
| from ...utils import filter_kwargs | ||
| import inspect | ||
|
|
||
|
|
||
| @serializable(package="bayesflow.adapters") | ||
| class SerializableCustomTransform(ElementwiseTransform): | ||
| """ | ||
| Transforms a parameter using a pair of registered serializable forward and inverse functions. | ||
| Parameters | ||
| ---------- | ||
| serializable_forward_fn : function, no lambda | ||
| Registered serializable function to transform the data in the forward pass. | ||
| For the adapter to be serializable, this function has to be serializable | ||
| as well (see Notes). Therefore, only proper functions and no lambda | ||
| functions can be used here. | ||
| serializable_inverse_fn : function, no lambda | ||
| Function to transform the data in the inverse pass. | ||
| For the adapter to be serializable, this function has to be serializable | ||
| as well (see Notes). Therefore, only proper functions and no lambda | ||
| functions can be used here. | ||
| Raises | ||
| ------ | ||
| ValueError | ||
| When the provided functions are not registered serializable functions. | ||
| Notes | ||
| ----- | ||
| Important: The forward and inverse functions have to be registered with Keras. | ||
| To do so, use the `@keras.saving.register_keras_serializable` decorator. | ||
| They must also be registered (and identical) when loading the adapter | ||
| at a later point in time. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| serializable_forward_fn: Callable[[np.ndarray, ...], np.ndarray], | ||
| serializable_inverse_fn: Callable[[np.ndarray, ...], np.ndarray], | ||
| ): | ||
| super().__init__() | ||
|
|
||
| self._check_serializable(serializable_forward_fn, label="serializable_forward_fn") | ||
| self._check_serializable(serializable_inverse_fn, label="serializable_inverse_fn") | ||
| self._forward = serializable_forward_fn | ||
| self._inverse = serializable_inverse_fn | ||
|
|
||
| @classmethod | ||
| def _check_serializable(cls, function, label=""): | ||
| GENERAL_EXAMPLE_CODE = f"""The example code below shows the structure of a correctly decorated function: | ||
| ``` | ||
| import keras | ||
| @keras.saving.register_keras_serializable('custom') | ||
| def my_{label}(...): | ||
| [your code goes here...] | ||
| ``` | ||
| """ | ||
vpratz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if function is None: | ||
| raise TypeError( | ||
| f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" | ||
| ) | ||
| registered_name = get_registered_name(function) | ||
| # check if function is a lambda function | ||
| if registered_name == "<lambda>": | ||
| raise ValueError( | ||
| f"The provided function for '{label}' is a lambda function, " | ||
| "which cannot be serialized. " | ||
| "Please provide a registered serializable function by using the " | ||
| "@keras.saving.register_keras_serializable decorator." | ||
| f"\n{GENERAL_EXAMPLE_CODE}" | ||
| ) | ||
| if inspect.ismethod(function): | ||
| raise ValueError( | ||
| f"The provided value for '{label}' is a method, not a function. " | ||
| "Methods cannot be serialized separately from their classes. " | ||
| "Please provide a registered serializable function instead by " | ||
| "moving the functionality to a function (i.e., outside of the class) and " | ||
| "using the @keras.saving.register_keras_serializable decorator." | ||
| f"\n{GENERAL_EXAMPLE_CODE}" | ||
| ) | ||
| registered_object_for_name = get_registered_object(registered_name) | ||
| if registered_object_for_name is None: | ||
| try: | ||
| source_max_lines = 5 | ||
| function_source_code = inspect.getsource(function).split("\n") | ||
| if len(function_source_code) > source_max_lines: | ||
| function_source_code = function_source_code[:source_max_lines] + [" [...]"] | ||
|
|
||
| example_code = "For your provided function, this would look like this:\n\n" | ||
| example_code += "\n".join( | ||
| ["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"] | ||
| + function_source_code | ||
| + ["```"] | ||
| ) | ||
| except OSError: | ||
| example_code = GENERAL_EXAMPLE_CODE | ||
| raise ValueError( | ||
| f"The provided function for '{label}' is not registered with Keras.\n" | ||
| "Please register the function using the " | ||
| "@keras.saving.register_keras_serializable decorator.\n" | ||
| f"{example_code}" | ||
| ) | ||
| if registered_object_for_name is not function: | ||
| raise ValueError( | ||
| f"The provided function for '{label}' does not match the function " | ||
| f"registered under its name '{registered_name}'. " | ||
| f"(registered function: {registered_object_for_name}, provided function: {function}). " | ||
| ) | ||
|
|
||
| @classmethod | ||
| def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform": | ||
| if get_registered_object(config["forward"]["config"], custom_objects) is None: | ||
| provided_function_msg = "" | ||
| if config["_forward_source_code"]: | ||
| provided_function_msg = ( | ||
| f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```" | ||
| ) | ||
| raise TypeError( | ||
| "\n\nPLEASE READ HERE:\n" | ||
| "-----------------\n" | ||
| "The forward function that was provided as `serializable_forward_fn` " | ||
| "is not registered with Keras, making deserialization impossible. " | ||
| f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " | ||
| "function before loading your model." | ||
| f"{provided_function_msg}" | ||
| ) | ||
| if get_registered_object(config["inverse"]["config"], custom_objects) is None: | ||
| provided_function_msg = "" | ||
| if config["_inverse_source_code"]: | ||
| provided_function_msg = ( | ||
| f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```" | ||
| ) | ||
| raise TypeError( | ||
| "\n\nPLEASE READ HERE:\n" | ||
| "-----------------\n" | ||
| "The inverse function that was provided as `serializable_inverse_fn` " | ||
| "is not registered with Keras, making deserialization impossible. " | ||
| f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " | ||
| "function before loading your model." | ||
| f"{provided_function_msg}" | ||
| ) | ||
| forward = deserialize(config["forward"], custom_objects) | ||
| inverse = deserialize(config["inverse"], custom_objects) | ||
| return cls( | ||
| serializable_forward_fn=forward, | ||
| serializable_inverse_fn=inverse, | ||
| ) | ||
|
|
||
| def get_config(self) -> dict: | ||
| forward_source_code = inverse_source_code = None | ||
| try: | ||
| forward_source_code = inspect.getsource(self._forward) | ||
| inverse_source_code = inspect.getsource(self._inverse) | ||
| except OSError: | ||
| pass | ||
| return { | ||
| "forward": serialize(self._forward), | ||
| "inverse": serialize(self._inverse), | ||
| "_forward_source_code": forward_source_code, | ||
| "_inverse_source_code": inverse_source_code, | ||
| } | ||
|
|
||
| def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: | ||
| # filter kwargs so that other transform args like batch_size, strict, ... are not passed through | ||
| kwargs = filter_kwargs(kwargs, self._forward) | ||
| return self._forward(data, **kwargs) | ||
|
|
||
| def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: | ||
| kwargs = filter_kwargs(kwargs, self._inverse) | ||
| return self._inverse(data, **kwargs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.