Skip to content

Commit 1ae6bda

Browse files
committed
optimize serializable decorator by getting the parent package name with sys._getframe instead of inspect.getmodule
1 parent 50de632 commit 1ae6bda

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

bayesflow/utils/serialization.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import keras
66
import numpy as np
7+
import sys
78

89
# this import needs to be exactly like this to work with monkey patching
910
from keras.saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
9798
# we marked this as a type during serialization
9899
obj = obj[len(_type_prefix) :]
99100
tp = keras.saving.get_registered_object(
100-
obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
101+
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
102+
obj,
103+
custom_objects=custom_objects,
104+
module_objects=np.__dict__ | builtins.__dict__,
101105
)
102106
if tp is None:
103107
raise ValueError(
@@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
117121
@allow_args
118122
def serializable(cls, package=None, name=None):
119123
if package is None:
120-
# get the calling module's name, e.g. "bayesflow.networks.inference_network"
121-
stack = inspect.stack()
122-
module = inspect.getmodule(stack[1][0])
123-
package = copy(module.__name__)
124+
frame = sys._getframe(1)
125+
g = frame.f_globals
126+
package = g.get("__name__", "bayesflow")
124127

125128
if name is None:
126129
name = copy(cls.__name__)

0 commit comments

Comments
 (0)