44import numpy as np
55from keras .saving import (
66 deserialize_keras_object as deserialize ,
7+ get_registered_name ,
8+ get_registered_object ,
79 register_keras_serializable as serializable ,
810 serialize_keras_object as serialize ,
911)
@@ -79,21 +81,33 @@ def extra_repr(self) -> str:
7981
8082 @classmethod
8183 def from_config (cls , config : dict , custom_objects = None ) -> "Transform" :
82- def transform_constructor (* args , ** kwargs ):
83- raise RuntimeError (
84- "Instantiating new elementwise transforms on a deserialized FilterTransform is not yet supported (and"
85- "may never be). As a work-around, you can manually register the elementwise transform constructor after"
86- "deserialization:\n "
87- "obj = deserialize(config)\n "
88- "obj.transform_constructor = MyElementwiseTransform"
89- )
90-
84+ transform_constructor = get_registered_object (config ["transform_constructor" ])
85+ try :
86+ kwargs = deserialize (config ["kwargs" ])
87+ except TypeError as e :
88+ if transform_constructor .__name__ == "LambdaTransform" :
89+ raise TypeError (
90+ "LambdaTransform (created by Adapter.apply) could not be deserialized.\n "
91+ "This is probably because the custom transform functions `forward` and "
92+ "`backward` from `Adapter.apply` were not passed as `custom_objects`.\n "
93+ "For example, if your adapter uses\n "
94+ "`Adapter.apply(forward=forward_transform, inverse=inverse_transform)`,\n "
95+ "you have to pass\n "
96+ '`custom_objects={"forward_transform": forward_transform, '
97+ '"inverse_transform": inverse_transform}`\n '
98+ "to the function you use to load the serialized object."
99+ ) from e
100+ raise TypeError (
101+ "The transform could not be deserialized properly. "
102+ "The most likely reason is that some classes or functions "
103+ "are not known during deserialization. Please pass them as `custom_objects`."
104+ ) from e
91105 instance = cls (
92106 transform_constructor = transform_constructor ,
93107 predicate = deserialize (config ["predicate" ], custom_objects ),
94108 include = deserialize (config ["include" ], custom_objects ),
95109 exclude = deserialize (config ["exclude" ], custom_objects ),
96- ** config [ " kwargs" ] ,
110+ ** kwargs ,
97111 )
98112
99113 instance .transform_map = deserialize (config ["transform_map" ])
@@ -102,6 +116,7 @@ def transform_constructor(*args, **kwargs):
102116
103117 def get_config (self ) -> dict :
104118 return {
119+ "transform_constructor" : get_registered_name (self .transform_constructor ),
105120 "predicate" : serialize (self .predicate ),
106121 "include" : serialize (self .include ),
107122 "exclude" : serialize (self .exclude ),
0 commit comments