44import inspect
55import keras
66import numpy as np
7+ import sys
78
89# this import needs to be exactly like this to work with monkey patching
910from keras .saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
9798 # we marked this as a type during serialization
9899 obj = obj [len (_type_prefix ) :]
99100 tp = keras .saving .get_registered_object (
100- obj , custom_objects = custom_objects , module_objects = builtins .__dict__ | np .__dict__
101+ # TODO: can we pass module objects without overwriting numpy's dict with builtins?
102+ obj ,
103+ custom_objects = custom_objects ,
104+ module_objects = np .__dict__ | builtins .__dict__ ,
101105 )
102106 if tp is None :
103107 raise ValueError (
@@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
117121@allow_args
118122def serializable (cls , package = None , name = None ):
119123 if package is None :
120- # get the calling module's name, e.g. "bayesflow.networks.inference_network"
121- stack = inspect .stack ()
122- module = inspect .getmodule (stack [1 ][0 ])
123- package = copy (module .__name__ )
124+ frame = sys ._getframe (1 )
125+ g = frame .f_globals
126+ package = g .get ("__name__" , "bayesflow" )
124127
125128 if name is None :
126129 name = copy (cls .__name__ )
0 commit comments