From 1ae6bda74850acdcd1c796bb6410f5c84a59ecdf Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 22 Apr 2025 15:17:08 -0400
Subject: [PATCH 1/3] optimize serializable decorator by getting the parent
 package name with sys._getframe instead of inspect.getmodule

---
 bayesflow/utils/serialization.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py
index c6e804c74..500264f05 100644
--- a/bayesflow/utils/serialization.py
+++ b/bayesflow/utils/serialization.py
@@ -4,6 +4,7 @@
 import inspect
 import keras
 import numpy as np
+import sys
 
 # this import needs to be exactly like this to work with monkey patching
 from keras.saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
         # we marked this as a type during serialization
         obj = obj[len(_type_prefix) :]
         tp = keras.saving.get_registered_object(
-            obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
+            # TODO: can we pass module objects without overwriting numpy's dict with builtins?
+            obj,
+            custom_objects=custom_objects,
+            module_objects=np.__dict__ | builtins.__dict__,
         )
         if tp is None:
             raise ValueError(
@@ -117,10 +121,9 @@
 @allow_args
 def serializable(cls, package=None, name=None):
     if package is None:
-        # get the calling module's name, e.g. "bayesflow.networks.inference_network"
-        stack = inspect.stack()
-        module = inspect.getmodule(stack[1][0])
-        package = copy(module.__name__)
+        frame = sys._getframe(1)
+        g = frame.f_globals
+        package = g.get("__name__", "bayesflow")
 
     if name is None:
         name = copy(cls.__name__)

From 988f6a4aebeba5475248100b45b9c9d0a574f12e Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 22 Apr 2025 15:18:54 -0400
Subject: [PATCH 2/3] optimize _add_imports_to_all

---
 bayesflow/utils/_docs/_populate_all.py | 40 +++++++++++++++-----------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/bayesflow/utils/_docs/_populate_all.py b/bayesflow/utils/_docs/_populate_all.py
index 50da8048b..d9718c2f5 100644
--- a/bayesflow/utils/_docs/_populate_all.py
+++ b/bayesflow/utils/_docs/_populate_all.py
@@ -1,4 +1,5 @@
-import inspect
+import sys
+import types
 
 
 def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
@@ -6,18 +7,25 @@ def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list
     if not isinstance(include_modules, (bool, list)):
         raise ValueError("include_modules must be a boolean or a list of strings")
 
-    exclude = exclude or []
-    calling_module = inspect.stack()[1]
-    local_stack = calling_module[0]
-    global_vars = local_stack.f_globals
-    all_vars = global_vars["__all__"] if "__all__" in global_vars else []
-    included_vars = []
-    for var_name in set(global_vars.keys()):
-        if inspect.ismodule(global_vars[var_name]):
-            if include_modules is True and var_name not in exclude and not var_name.startswith("_"):
-                included_vars.append(var_name)
-            elif isinstance(include_modules, list) and var_name in include_modules:
-                included_vars.append(var_name)
-        elif var_name not in exclude and not var_name.startswith("_"):
-            included_vars.append(var_name)
-    global_vars["__all__"] = sorted(list(set(all_vars).union(included_vars)))
+    exclude_set = set(exclude or [])
+    contains = exclude_set.__contains__
+    mod_type = types.ModuleType
+    frame = sys._getframe(1)
+    g: dict = frame.f_globals
+    existing = set(g.get("__all__", []))
+
+    to_add = []
+    include_list = include_modules if isinstance(include_modules, list) else ()
+    inc_all = include_modules is True
+
+    for name, val in g.items():
+        if name.startswith("_") or contains(name):
+            continue
+
+        if isinstance(val, mod_type):
+            if inc_all or name in include_list:
+                to_add.append(name)
+        else:
+            to_add.append(name)
+
+    g["__all__"] = sorted(existing.union(to_add))

From 0e54bf7cc03837d8ad2192f0b139e42da05bf43b Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 22 Apr 2025 15:19:13 -0400
Subject: [PATCH 3/3] fix numpy-keras interop in tests

---
 tests/test_utils/test_dispatch.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/tests/test_utils/test_dispatch.py b/tests/test_utils/test_dispatch.py
index 8fc0f91f0..85e326445 100644
--- a/tests/test_utils/test_dispatch.py
+++ b/tests/test_utils/test_dispatch.py
@@ -3,6 +3,7 @@
 
 # Import the dispatch functions
 from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
+from tests.utils import assert_allclose
 
 
 # --- Tests for find_network.py ---
@@ -118,23 +119,21 @@ def test_find_pooling_mean():
     # Check that a keras Lambda layer is returned
     assert isinstance(pooling, keras.layers.Lambda)
     # Test that the lambda function produces a mean when applied to a sample tensor.
-    import numpy as np
 
-    sample = np.array([[1, 2], [3, 4]])
+    sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
     # Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
     result = pooling.call(sample)
-    np.testing.assert_allclose(result, sample.mean(axis=-2))
+    assert_allclose(result, keras.ops.mean(sample, axis=-2))
 
 
 @pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
 def test_find_pooling_max_min(name, func):
     pooling = find_pooling(name)
     assert isinstance(pooling, keras.layers.Lambda)
-    import numpy as np
 
-    sample = np.array([[1, 2], [3, 4]])
+    sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
     result = pooling.call(sample)
-    np.testing.assert_allclose(result, func(sample, axis=-2))
+    assert_allclose(result, func(sample, axis=-2))
 
 
 def test_find_pooling_learnable(monkeypatch):