Skip to content

Commit 5a8d624

Browse files
committed
use monkey-patching to enable type deserialization
1 parent d07687f commit 5a8d624

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

bayesflow/utils/serialization.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
from copy import copy
22

3+
import builtins
34
import inspect
45
import keras
6+
import numpy as np
57

8+
# this import needs to be exactly like this to work with monkey patching
9+
from keras.saving import deserialize_keras_object
10+
11+
from .context_managers import monkey_patch
612
from .decorators import allow_args
713

814

915
PREFIX = "_bayesflow_"
1016

17+
_type_prefix = "__bayesflow_type__"
18+
1119

1220
def serialize_value_or_type(config, name, obj):
1321
"""Serialize an object that can be either a value or a type
@@ -83,10 +91,27 @@ def deserialize_value_or_type(config, name):
8391
return updated_config
8492

8593

86-
def deserialize(obj, custom_objects=None, module_objects=None):
87-
if inspect.isclass(obj):
88-
return keras.saving.get_registered_object(obj, custom_objects=custom_objects, module_objects=module_objects)
89-
return keras.saving.deserialize_keras_object(obj, custom_objects=custom_objects, module_objects=module_objects)
94+
def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
95+
with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize:
96+
if isinstance(obj, str) and obj.startswith(_type_prefix):
97+
# we marked this as a type during serialization
98+
obj = obj[len(_type_prefix) :]
99+
tp = keras.saving.get_registered_object(
100+
obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
101+
)
102+
if tp is None:
103+
raise ValueError(
104+
f"Could not deserialize type {obj!r}. Make sure it is registered with "
105+
f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
106+
)
107+
return tp
108+
if inspect.isclass(obj):
109+
# add this base case since keras does not cover it
110+
return obj
111+
112+
obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
113+
114+
return obj
90115

91116

92117
@allow_args
@@ -107,7 +132,7 @@ def serializable(cls, package=None, name=None):
107132
def serialize(obj):
108133
if isinstance(obj, (tuple, list, dict)):
109134
return keras.tree.map_structure(serialize, obj)
135+
elif inspect.isclass(obj):
136+
return _type_prefix + keras.saving.get_registered_name(obj)
110137

111-
if inspect.isclass(obj):
112-
return keras.saving.get_registered_name(obj)
113138
return keras.saving.serialize_keras_object(obj)

0 commit comments

Comments
 (0)