11from copy import copy
22
3+ import builtins
34import inspect
45import keras
6+ import numpy as np
57
8+ # this import needs to be exactly like this to work with monkey patching
9+ from keras .saving import deserialize_keras_object
10+
11+ from .context_managers import monkey_patch
612from .decorators import allow_args
713
814
915PREFIX = "_bayesflow_"
1016
17+ _type_prefix = "__bayesflow_type__"
18+
1119
1220def serialize_value_or_type (config , name , obj ):
1321 """Serialize an object that can be either a value or a type
@@ -83,10 +91,27 @@ def deserialize_value_or_type(config, name):
8391 return updated_config
8492
8593
86- def deserialize (obj , custom_objects = None , module_objects = None ):
87- if inspect .isclass (obj ):
88- return keras .saving .get_registered_object (obj , custom_objects = custom_objects , module_objects = module_objects )
89- return keras .saving .deserialize_keras_object (obj , custom_objects = custom_objects , module_objects = module_objects )
94+ def deserialize (obj , custom_objects = None , safe_mode = True , ** kwargs ):
95+ with monkey_patch (deserialize_keras_object , deserialize ) as original_deserialize :
96+ if isinstance (obj , str ) and obj .startswith (_type_prefix ):
97+ # we marked this as a type during serialization
98+ obj = obj [len (_type_prefix ) :]
99+ tp = keras .saving .get_registered_object (
100+ obj , custom_objects = custom_objects , module_objects = builtins .__dict__ | np .__dict__
101+ )
102+ if tp is None :
103+ raise ValueError (
104+ f"Could not deserialize type { obj !r} . Make sure it is registered with "
105+ f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
106+ )
107+ return tp
108+ if inspect .isclass (obj ):
109+ # add this base case since keras does not cover it
110+ return obj
111+
112+ obj = original_deserialize (obj , custom_objects = custom_objects , safe_mode = safe_mode , ** kwargs )
113+
114+ return obj
90115
91116
92117@allow_args
@@ -107,7 +132,7 @@ def serializable(cls, package=None, name=None):
107132def serialize (obj ):
108133 if isinstance (obj , (tuple , list , dict )):
109134 return keras .tree .map_structure (serialize , obj )
135+ elif inspect .isclass (obj ):
136+ return _type_prefix + keras .saving .get_registered_name (obj )
110137
111- if inspect .isclass (obj ):
112- return keras .saving .get_registered_name (obj )
113138 return keras .saving .serialize_keras_object (obj )
0 commit comments