40 changes: 24 additions & 16 deletions bayesflow/utils/_docs/_populate_all.py
@@ -1,23 +1,31 @@
import inspect
import sys
import types


def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
"""Add all global variables to __all__"""
if not isinstance(include_modules, (bool, list)):
raise ValueError("include_modules must be a boolean or a list of strings")

exclude = exclude or []
calling_module = inspect.stack()[1]
local_stack = calling_module[0]
global_vars = local_stack.f_globals
all_vars = global_vars["__all__"] if "__all__" in global_vars else []
included_vars = []
for var_name in set(global_vars.keys()):
if inspect.ismodule(global_vars[var_name]):
if include_modules is True and var_name not in exclude and not var_name.startswith("_"):
included_vars.append(var_name)
elif isinstance(include_modules, list) and var_name in include_modules:
included_vars.append(var_name)
elif var_name not in exclude and not var_name.startswith("_"):
included_vars.append(var_name)
global_vars["__all__"] = sorted(list(set(all_vars).union(included_vars)))
exclude_set = set(exclude or [])
contains = exclude_set.__contains__
mod_type = types.ModuleType
frame = sys._getframe(1)
g: dict = frame.f_globals
existing = set(g.get("__all__", []))

to_add = []
include_list = include_modules if isinstance(include_modules, list) else ()
inc_all = include_modules is True

for name, val in g.items():
if name.startswith("_") or contains(name):
continue

if isinstance(val, mod_type):
if inc_all or name in include_list:
to_add.append(name)
else:
to_add.append(name)

g["__all__"] = sorted(existing.union(to_add))
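A usage note: below is a minimal sketch of how the helper might be called from a package __init__.py. The import path and re-exported names are illustrative, not taken from this PR; the one hard requirement, visible in the new implementation, is that the call happens at module level, since sys._getframe(1) resolves to the caller's globals.

# hypothetical __init__.py; names and import path are placeholders
from ._populate_all import _add_imports_to_all

from .networks import FlowMatching  # example re-export, picked up into __all__

# include the "networks" submodule explicitly; other modules are skipped
_add_imports_to_all(include_modules=["networks"], exclude=["keras"])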
13 changes: 8 additions & 5 deletions bayesflow/utils/serialization.py
@@ -4,6 +4,7 @@
import inspect
import keras
import numpy as np
import sys

# this import needs to be exactly like this to work with monkey patching
from keras.saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
# we marked this as a type during serialization
obj = obj[len(_type_prefix) :]
tp = keras.saving.get_registered_object(
obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
obj,
custom_objects=custom_objects,
module_objects=np.__dict__ | builtins.__dict__,
)
if tp is None:
raise ValueError(
@@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
@allow_args
def serializable(cls, package=None, name=None):
if package is None:
# get the calling module's name, e.g. "bayesflow.networks.inference_network"
stack = inspect.stack()
module = inspect.getmodule(stack[1][0])
package = copy(module.__name__)
frame = sys._getframe(1)
g = frame.f_globals
package = g.get("__name__", "bayesflow")

if name is None:
name = copy(cls.__name__)
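Regarding the TODO above the reordered union: PEP 584's | operator keeps the right operand's value on key collisions, so module_objects=np.__dict__ | builtins.__dict__ now lets builtins shadow numpy for shared names such as abs, min, and max. A self-contained illustration (not part of the PR):

import builtins
import numpy as np

merged = np.__dict__ | builtins.__dict__
assert merged["abs"] is builtins.abs  # right operand wins on collisions
assert (builtins.__dict__ | np.__dict__)["abs"] is np.abs  # the previous order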
7 changes: 5 additions & 2 deletions examples/Linear_Regression_Starter.ipynb
@@ -38,6 +38,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"from pathlib import Path\n",
"\n",
"import keras\n",
"import bayesflow as bf"
@@ -955,7 +956,9 @@
"outputs": [],
"source": [
"# Recommended - full serialization (checkpoints folder must exist)\n",
"workflow.approximator.save(filepath=\"checkpoints/regression.keras\")\n",
"filepath = Path(\"checkpoints\") / \"regression.keras\"\n",
"filepath.parent.mkdir(exist_ok=True)\n",
"workflow.approximator.save(filepath=filepath)\n",
"\n",
"# Not recommended due to adapter mismatches - weights only\n",
"# approximator.save_weights(filepath=\"checkpoints/regression.h5\")"
@@ -975,7 +978,7 @@
"outputs": [],
"source": [
"# Load approximator\n",
"approximator = keras.saving.load_model(\"checkpoints/regression.keras\")"
"approximator = keras.saving.load_model(filepath)"
]
},
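One hedged aside on the mkdir call in the save cell above: exist_ok=True alone only tolerates a directory that already exists; if the checkpoint path were ever nested, the parents flag would also be needed. The path below is illustrative:

from pathlib import Path

# creates intermediate directories too; no error if they already exist
Path("checkpoints/run_1").mkdir(parents=True, exist_ok=True)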
{
@@ -37,6 +37,7 @@
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from pathlib import Path\n",
"import seaborn as sns\n",
"\n",
"import scipy\n",
@@ -748,7 +749,8 @@
"metadata": {},
"outputs": [],
"source": [
"checkpoint_path = \"checkpoints/model.keras\"\n",
"checkpoint_path = Path(\"checkpoints\") / \"model.keras\"\n",
"checkpoint_path.parent.mkdir(exist_ok=True)\n",
"keras.saving.save_model(point_inference_workflow.approximator, checkpoint_path)"
]
},
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -37,6 +37,8 @@ all = [
"jupyter",
"jupyterlab",
"nbconvert",
"ipython",
"ipykernel",
"pre-commit",
"ruff",
"tox",
@@ -72,6 +74,8 @@
]
test = [
"nbconvert",
"ipython",
"ipykernel",
"pytest",
"pytest-cov",
"pytest-rerunfailures",
3 changes: 2 additions & 1 deletion tests/test_examples/test_examples.py
@@ -3,6 +3,7 @@
from tests.utils import run_notebook


@pytest.mark.skip(reason="requires setting up Stan")
@pytest.mark.slow
def test_bayesian_experimental_design(examples_path):
run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb")
@@ -30,7 +31,7 @@ def test_one_sample_ttest(examples_path):

@pytest.mark.slow
def test_sir_posterior_estimation(examples_path):
run_notebook(examples_path / "SIR_Posterior_estimation.ipynb")
run_notebook(examples_path / "SIR_Posterior_Estimation.ipynb")


@pytest.mark.slow
4 changes: 3 additions & 1 deletion tests/test_networks/conftest.py
@@ -85,6 +85,7 @@ def typical_point_inference_network_subnet():
"spline_coupling_flow",
"flow_matching",
"free_form_flow",
"consistency_model",
],
scope="function",
)
@@ -106,7 +107,8 @@ def inference_network_subnet(request):


@pytest.fixture(
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow"], scope="function"
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"],
scope="function",
)
def generative_inference_network(request):
return request.getfixturevalue(request.param)
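With consistency_model added to these fixture params, the tests in test_inference_networks.py below wrap forward and density calls in try/except NotImplementedError. A minimal sketch of the contract being assumed; the class is hypothetical, not bayesflow's API:

class SamplingOnlyNetwork:
    """Stand-in for a one-way generative network such as a consistency model."""

    def __call__(self, x, conditions=None, inverse=False, density=False):
        if not inverse:
            # no tractable forward pass or density, so tests skip via try/except
            raise NotImplementedError("this network only supports sampling")
        return x  # placeholder for the inverse (sampling) direction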
38 changes: 29 additions & 9 deletions tests/test_networks/test_inference_networks.py
@@ -36,13 +36,21 @@ def test_variable_batch_size(inference_network, random_samples, random_conditions):
else:
new_conditions = keras.ops.zeros((bs,) + keras.ops.shape(random_conditions)[1:])

inference_network(new_input, conditions=new_conditions)
try:
inference_network(new_input, conditions=new_conditions)
except NotImplementedError:
# network is not invertible
pass
inference_network(new_input, conditions=new_conditions, inverse=True)


@pytest.mark.parametrize("density", [True, False])
def test_output_structure(density, generative_inference_network, random_samples, random_conditions):
output = generative_inference_network(random_samples, conditions=random_conditions, density=density)
try:
output = generative_inference_network(random_samples, conditions=random_conditions, density=density)
except NotImplementedError:
# network is not invertible
return

if density:
assert isinstance(output, tuple)
@@ -57,9 +65,13 @@ def test_output_structure(density, generative_inference_network, random_samples, random_conditions):


def test_output_shape(generative_inference_network, random_samples, random_conditions):
forward_output, forward_log_density = generative_inference_network(
random_samples, conditions=random_conditions, density=True
)
try:
forward_output, forward_log_density = generative_inference_network(
random_samples, conditions=random_conditions, density=True
)
except NotImplementedError:
# network is not invertible, no forward function available
return

assert keras.ops.shape(forward_output) == keras.ops.shape(random_samples)
assert keras.ops.shape(forward_log_density) == (keras.ops.shape(random_samples)[0],)
@@ -74,9 +86,13 @@ def test_output_shape(generative_inference_network, random_samples, random_conditions):

def test_cycle_consistency(generative_inference_network, random_samples, random_conditions):
# cycle-consistency means the forward and inverse methods are inverses of each other
forward_output, forward_log_density = generative_inference_network(
random_samples, conditions=random_conditions, density=True
)
try:
forward_output, forward_log_density = generative_inference_network(
random_samples, conditions=random_conditions, density=True
)
except NotImplementedError:
# network is not invertible, cycle consistency cannot be tested.
return
inverse_output, inverse_log_density = generative_inference_network(
forward_output, conditions=random_conditions, density=True, inverse=True
)
@@ -88,7 +104,11 @@ def test_density_numerically(generative_inference_network, random_samples, random_conditions):
def test_density_numerically(generative_inference_network, random_samples, random_conditions):
from bayesflow.utils import jacobian

output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
try:
output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
except NotImplementedError:
# network does not support density estimation
return

def f(x):
return generative_inference_network(x, conditions=random_conditions)
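For context, the numeric check rests on the change-of-variables identity log p_X(x) = log p_Z(f(x)) + log|det J_f(x)| for a forward map f from data to latent space. A self-contained toy version with an affine map, independent of bayesflow's jacobian helper:

import numpy as np

def f(x):
    # toy invertible map; its Jacobian at any x is diag(2.0, 2.0)
    return 2.0 * x + 1.0

x = np.array([0.5, -1.0])
z = f(x)
log_det = np.linalg.slogdet(np.diag([2.0, 2.0]))[1]  # log|det J| = log 4
# standard-normal latent density, a common choice for flows
log_p_z = -0.5 * np.sum(z**2) - (x.size / 2) * np.log(2 * np.pi)
log_p_x = log_p_z + log_det  # what a network's `density=True` should report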
11 changes: 5 additions & 6 deletions tests/test_utils/test_dispatch.py
@@ -3,6 +3,7 @@

# Import the dispatch functions
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
from tests.utils import assert_allclose

# --- Tests for find_network.py ---

@@ -118,23 +119,21 @@ def test_find_pooling_mean():
# Check that a keras Lambda layer is returned
assert isinstance(pooling, keras.layers.Lambda)
# Test that the lambda function produces a mean when applied to a sample tensor.
import numpy as np

sample = np.array([[1, 2], [3, 4]])
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
# Keras Lambda layers expect tensors via call(); here we simply call the layer's function.
result = pooling.call(sample)
np.testing.assert_allclose(result, sample.mean(axis=-2))
assert_allclose(result, keras.ops.mean(sample, axis=-2))


@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
def test_find_pooling_max_min(name, func):
pooling = find_pooling(name)
assert isinstance(pooling, keras.layers.Lambda)
import numpy as np

sample = np.array([[1, 2], [3, 4]])
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
result = pooling.call(sample)
np.testing.assert_allclose(result, func(sample, axis=-2))
assert_allclose(result, func(sample, axis=-2))
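Swapping np.array and np.testing for keras.ops and assert_allclose keeps these tests backend-agnostic, since keras.ops dispatches to whichever backend (TensorFlow, JAX, or PyTorch) is active. A quick illustration:

import keras

x = keras.ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
m = keras.ops.mean(x, axis=-2)  # a backend tensor equal to [2.0, 3.0]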


def test_find_pooling_learnable(monkeypatch):
18 changes: 16 additions & 2 deletions tests/utils/jupyter.py
@@ -1,11 +1,25 @@
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

from pathlib import Path
import shutil


def run_notebook(path):
path = Path(path)
checkpoint_path = path.parent / "checkpoints"
# only clean up if the directory did not exist before the test
cleanup_checkpoints = not checkpoint_path.exists()
with open(str(path)) as f:
nb = nbformat.read(f, nbformat.NO_CONVERT)

kernel = ExecutePreprocessor(timeout=600, kernel_name="python3")
kernel = ExecutePreprocessor(timeout=600, kernel_name="python3", resources={"metadata": {"path": path.parent}})

try:
result = kernel.preprocess(nb)
finally:
if cleanup_checkpoints and checkpoint_path.exists():
# clean up if the directory was created by the test
shutil.rmtree(checkpoint_path)

return kernel.preprocess(nb)
return result
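A sketch of how the updated helper is used from the test suite, mirroring tests/test_examples/test_examples.py above (the test name here is illustrative). The resources metadata makes nbconvert execute each notebook with its own folder as the working directory, so relative paths such as "checkpoints/" resolve next to the notebook:

import pytest
from tests.utils import run_notebook

@pytest.mark.slow
def test_linear_regression_starter(examples_path):
    run_notebook(examples_path / "Linear_Regression_Starter.ipynb")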