Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions bayesflow/utils/_docs/_populate_all.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
import inspect
import sys
import types


def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
"""Add all global variables to __all__"""
if not isinstance(include_modules, (bool, list)):
raise ValueError("include_modules must be a boolean or a list of strings")

exclude = exclude or []
calling_module = inspect.stack()[1]
local_stack = calling_module[0]
global_vars = local_stack.f_globals
all_vars = global_vars["__all__"] if "__all__" in global_vars else []
included_vars = []
for var_name in set(global_vars.keys()):
if inspect.ismodule(global_vars[var_name]):
if include_modules is True and var_name not in exclude and not var_name.startswith("_"):
included_vars.append(var_name)
elif isinstance(include_modules, list) and var_name in include_modules:
included_vars.append(var_name)
elif var_name not in exclude and not var_name.startswith("_"):
included_vars.append(var_name)
global_vars["__all__"] = sorted(list(set(all_vars).union(included_vars)))
exclude_set = set(exclude or [])
contains = exclude_set.__contains__
mod_type = types.ModuleType
frame = sys._getframe(1)
g: dict = frame.f_globals
existing = set(g.get("__all__", []))

to_add = []
include_list = include_modules if isinstance(include_modules, list) else ()
inc_all = include_modules is True

for name, val in g.items():
if name.startswith("_") or contains(name):
continue

if isinstance(val, mod_type):
if inc_all or name in include_list:
to_add.append(name)
else:
to_add.append(name)

g["__all__"] = sorted(existing.union(to_add))
13 changes: 8 additions & 5 deletions bayesflow/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import keras
import numpy as np
import sys

# this import needs to be exactly like this to work with monkey patching
from keras.saving import deserialize_keras_object
Expand Down Expand Up @@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
# we marked this as a type during serialization
obj = obj[len(_type_prefix) :]
tp = keras.saving.get_registered_object(
obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
obj,
custom_objects=custom_objects,
module_objects=np.__dict__ | builtins.__dict__,
)
if tp is None:
raise ValueError(
Expand All @@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
@allow_args
def serializable(cls, package=None, name=None):
if package is None:
# get the calling module's name, e.g. "bayesflow.networks.inference_network"
stack = inspect.stack()
module = inspect.getmodule(stack[1][0])
package = copy(module.__name__)
frame = sys._getframe(1)
g = frame.f_globals
package = g.get("__name__", "bayesflow")

if name is None:
name = copy(cls.__name__)
Expand Down
11 changes: 5 additions & 6 deletions tests/test_utils/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# Import the dispatch functions
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
from tests.utils import assert_allclose

# --- Tests for find_network.py ---

Expand Down Expand Up @@ -118,23 +119,21 @@ def test_find_pooling_mean():
# Check that a keras Lambda layer is returned
assert isinstance(pooling, keras.layers.Lambda)
# Test that the lambda function produces a mean when applied to a sample tensor.
import numpy as np

sample = np.array([[1, 2], [3, 4]])
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
# Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
result = pooling.call(sample)
np.testing.assert_allclose(result, sample.mean(axis=-2))
assert_allclose(result, keras.ops.mean(sample, axis=-2))


@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
def test_find_pooling_max_min(name, func):
    """find_pooling('max'/'min') returns a Lambda layer reducing over axis -2."""
    # NOTE: the pasted diff interleaved the old numpy-based lines with their
    # keras.ops replacements; this is the consolidated post-merge version.
    pooling = find_pooling(name)
    assert isinstance(pooling, keras.layers.Lambda)

    sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
    # Lambda layers wrap a function; call() applies it directly to the tensor.
    result = pooling.call(sample)
    assert_allclose(result, func(sample, axis=-2))


def test_find_pooling_learnable(monkeypatch):
Expand Down