Skip to content

Commit 8895928

Browse files
committed
[no ci] Merge remote-tracking branch 'upstream/main' into docs-migration-advice
2 parents 9a2f338 + d31a761 commit 8895928

File tree

13 files changed

+117
-120
lines changed

13 files changed

+117
-120
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ class FlowMatching(InferenceNetwork):
3939
}
4040

4141
OPTIMAL_TRANSPORT_DEFAULT_CONFIG = {
42-
"method": "sinkhorn",
43-
"cost": "euclidean",
42+
"method": "log_sinkhorn",
4443
"regularization": 0.1,
4544
"max_steps": 100,
46-
"tolerance": 1e-4,
45+
"atol": 1e-5,
46+
"rtol": 1e-4,
4747
}
4848

4949
INTEGRATE_DEFAULT_CONFIG = {
Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
1-
import inspect
1+
import sys
2+
import types
23

34

45
def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
56
"""Add all global variables to __all__"""
67
if not isinstance(include_modules, (bool, list)):
78
raise ValueError("include_modules must be a boolean or a list of strings")
89

9-
exclude = exclude or []
10-
calling_module = inspect.stack()[1]
11-
local_stack = calling_module[0]
12-
global_vars = local_stack.f_globals
13-
all_vars = global_vars["__all__"] if "__all__" in global_vars else []
14-
included_vars = []
15-
for var_name in set(global_vars.keys()):
16-
if inspect.ismodule(global_vars[var_name]):
17-
if include_modules is True and var_name not in exclude and not var_name.startswith("_"):
18-
included_vars.append(var_name)
19-
elif isinstance(include_modules, list) and var_name in include_modules:
20-
included_vars.append(var_name)
21-
elif var_name not in exclude and not var_name.startswith("_"):
22-
included_vars.append(var_name)
23-
global_vars["__all__"] = sorted(list(set(all_vars).union(included_vars)))
10+
exclude_set = set(exclude or [])
11+
contains = exclude_set.__contains__
12+
mod_type = types.ModuleType
13+
frame = sys._getframe(1)
14+
g: dict = frame.f_globals
15+
existing = set(g.get("__all__", []))
16+
17+
to_add = []
18+
include_list = include_modules if isinstance(include_modules, list) else ()
19+
inc_all = include_modules is True
20+
21+
for name, val in g.items():
22+
if name.startswith("_") or contains(name):
23+
continue
24+
25+
if isinstance(val, mod_type):
26+
if inc_all or name in include_list:
27+
to_add.append(name)
28+
else:
29+
to_add.append(name)
30+
31+
g["__all__"] = sorted(existing.union(to_add))

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import keras
22

33
from .. import logging
4-
from ..tensor_utils import is_symbolic_tensor
54

65
from .euclidean import euclidean
76

@@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,
2726

2827
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
2928

30-
if is_symbolic_tensor(log_plan):
31-
return log_plan
32-
3329
def contains_nans(plan):
3430
return keras.ops.any(keras.ops.isnan(plan))
3531

@@ -59,7 +55,7 @@ def do_nothing():
5955
def log_steps():
6056
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
6157

62-
logging.info(msg, steps)
58+
logging.debug(msg, steps)
6359

6460
def warn_convergence():
6561
marginals = keras.ops.logsumexp(log_plan, axis=0)

bayesflow/utils/serialization.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import keras
66
import numpy as np
7+
import sys
78

89
# this import needs to be exactly like this to work with monkey patching
910
from keras.saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
9798
# we marked this as a type during serialization
9899
obj = obj[len(_type_prefix) :]
99100
tp = keras.saving.get_registered_object(
100-
obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
101+
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
102+
obj,
103+
custom_objects=custom_objects,
104+
module_objects=np.__dict__ | builtins.__dict__,
101105
)
102106
if tp is None:
103107
raise ValueError(
@@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
117121
@allow_args
118122
def serializable(cls, package=None, name=None):
119123
if package is None:
120-
# get the calling module's name, e.g. "bayesflow.networks.inference_network"
121-
stack = inspect.stack()
122-
module = inspect.getmodule(stack[1][0])
123-
package = copy(module.__name__)
124+
frame = sys._getframe(1)
125+
g = frame.f_globals
126+
package = g.get("__name__", "bayesflow")
124127

125128
if name is None:
126129
name = copy(cls.__name__)

docsrc/source/index.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,15 @@ More tutorials are always welcome! Please consider making a pull request if you
6767
6868
.. tab-item:: pip
6969
70-
The v2 version is not available on PyPI yet, please install from source.
70+
.. code-block:: bash
71+
72+
pip install bayesflow
7173
7274
.. tab-item:: source
7375
7476
.. code-block:: bash
7577
76-
pip install git+https://github.com/bayesflow-org/bayesflow.git
78+
pip install git+https://github.com/bayesflow-org/bayesflow.git@dev
7779
```
7880

7981
### Backend

examples/Linear_Regression_Starter.ipynb

Lines changed: 14 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -28,30 +28,17 @@
2828
},
2929
{
3030
"cell_type": "code",
31-
"execution_count": 1,
31+
"execution_count": null,
3232
"metadata": {
3333
"ExecuteTime": {
3434
"end_time": "2025-02-14T10:51:27.573003Z",
3535
"start_time": "2025-02-14T10:51:27.568939Z"
3636
}
3737
},
38-
"outputs": [
39-
{
40-
"name": "stderr",
41-
"output_type": "stream",
42-
"text": [
43-
"WARNING:bayesflow:\n",
44-
"When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n",
45-
"\n",
46-
"with torch.enable_grad():\n",
47-
" ...\n",
48-
"\n",
49-
"in contexts where you need gradients (e.g. custom training loops).\n"
50-
]
51-
}
52-
],
38+
"outputs": [],
5339
"source": [
5440
"import numpy as np\n",
41+
"from pathlib import Path\n",
5542
"\n",
5643
"import keras\n",
5744
"import bayesflow as bf"
@@ -598,7 +585,7 @@
598585
},
599586
{
600587
"cell_type": "code",
601-
"execution_count": 19,
588+
"execution_count": null,
602589
"metadata": {
603590
"ExecuteTime": {
604591
"end_time": "2025-02-14T10:52:51.132695Z",
@@ -618,7 +605,7 @@
618605
}
619606
],
620607
"source": [
621-
"f = bf.diagnostics.plots.loss(history, )"
608+
"f = bf.diagnostics.plots.loss(history)"
622609
]
623610
},
624611
{
@@ -964,30 +951,14 @@
964951
},
965952
{
966953
"cell_type": "code",
967-
"execution_count": 30,
954+
"execution_count": null,
968955
"metadata": {},
969-
"outputs": [
970-
{
971-
"name": "stderr",
972-
"output_type": "stream",
973-
"text": [
974-
"2025-04-21 11:54:04.969579: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
975-
"2025-04-21 11:54:04.977366: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
976-
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
977-
"E0000 00:00:1745250844.984817 4140753 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
978-
"E0000 00:00:1745250844.987174 4140753 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
979-
"W0000 00:00:1745250844.993850 4140753 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
980-
"W0000 00:00:1745250844.993860 4140753 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
981-
"W0000 00:00:1745250844.993861 4140753 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
982-
"W0000 00:00:1745250844.993863 4140753 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
983-
"2025-04-21 11:54:04.996047: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
984-
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
985-
]
986-
}
987-
],
956+
"outputs": [],
988957
"source": [
989958
"# Recommended - full serialization (checkpoints folder must exist)\n",
990-
"workflow.approximator.save(filepath=\"checkpoints/regression.keras\")\n",
959+
"filepath = Path(\"checkpoints\") / \"regression.keras\"\n",
960+
"filepath.parent.mkdir(exist_ok=True)\n",
961+
"workflow.approximator.save(filepath=filepath)\n",
991962
"\n",
992963
"# Not recommended due to adapter mismatches - weights only\n",
993964
"# approximator.save_weights(filepath=\"checkpoints/regression.h5\")"
@@ -1002,21 +973,12 @@
1002973
},
1003974
{
1004975
"cell_type": "code",
1005-
"execution_count": 31,
976+
"execution_count": null,
1006977
"metadata": {},
1007-
"outputs": [
1008-
{
1009-
"name": "stderr",
1010-
"output_type": "stream",
1011-
"text": [
1012-
"/home/radevs/anaconda3/envs/bf/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py:734: UserWarning: `compile()` was not called as part of model loading because the model's `compile()` method is custom. All subclassed Models that have `compile()` overridden should also override `get_compile_config()` and `compile_from_config(config)`. Alternatively, you can call `compile()` manually after loading.\n",
1013-
" instance.compile_from_config(compile_config)\n"
1014-
]
1015-
}
1016-
],
978+
"outputs": [],
1017979
"source": [
1018980
"# Load approximator\n",
1019-
"approximator = keras.saving.load_model(\"checkpoints/regression.keras\")"
981+
"approximator = keras.saving.load_model(filepath)"
1020982
]
1021983
},
1022984
{
@@ -1052,13 +1014,6 @@
10521014
" variable_names=par_names\n",
10531015
")"
10541016
]
1055-
},
1056-
{
1057-
"cell_type": "code",
1058-
"execution_count": null,
1059-
"metadata": {},
1060-
"outputs": [],
1061-
"source": []
10621017
}
10631018
],
10641019
"metadata": {
@@ -1073,16 +1028,7 @@
10731028
"name": "python3"
10741029
},
10751030
"language_info": {
1076-
"codemirror_mode": {
1077-
"name": "ipython",
1078-
"version": 3
1079-
},
1080-
"file_extension": ".py",
1081-
"mimetype": "text/x-python",
1082-
"name": "python",
1083-
"nbconvert_exporter": "python",
1084-
"pygments_lexer": "ipython3",
1085-
"version": "3.11.11"
1031+
"name": "python"
10861032
},
10871033
"widgets": {
10881034
"application/vnd.jupyter.widget-state+json": {

examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"source": [
3838
"import matplotlib.pyplot as plt\n",
3939
"import numpy as np\n",
40+
"from pathlib import Path\n",
4041
"import seaborn as sns\n",
4142
"\n",
4243
"import scipy\n",
@@ -748,7 +749,8 @@
748749
"metadata": {},
749750
"outputs": [],
750751
"source": [
751-
"checkpoint_path = \"checkpoints/model.keras\"\n",
752+
"checkpoint_path = Path(\"checkpoints\") / \"model.keras\"\n",
753+
"checkpoint_path.parent.mkdir(exist_ok=True)\n",
752754
"keras.saving.save_model(point_inference_workflow.approximator, checkpoint_path)"
753755
]
754756
},

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "bayesflow"
7-
version = "2.0.0"
7+
version = "2.0.1"
88
authors = [{ name = "The BayesFlow Team" }]
99
classifiers = [
1010
"Development Status :: 5 - Production/Stable",
@@ -37,6 +37,8 @@ all = [
3737
"jupyter",
3838
"jupyterlab",
3939
"nbconvert",
40+
"ipython",
41+
"ipykernel",
4042
"pre-commit",
4143
"ruff",
4244
"tox",
@@ -72,6 +74,8 @@ docs = [
7274
]
7375
test = [
7476
"nbconvert",
77+
"ipython",
78+
"ipykernel",
7579
"pytest",
7680
"pytest-cov",
7781
"pytest-rerunfailures",

tests/test_examples/test_examples.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from tests.utils import run_notebook
44

55

6+
@pytest.mark.skip(reason="requires setting up Stan")
67
@pytest.mark.slow
78
def test_bayesian_experimental_design(examples_path):
89
run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb")
@@ -30,7 +31,7 @@ def test_one_sample_ttest(examples_path):
3031

3132
@pytest.mark.slow
3233
def test_sir_posterior_estimation(examples_path):
33-
run_notebook(examples_path / "SIR_Posterior_estimation.ipynb")
34+
run_notebook(examples_path / "SIR_Posterior_Estimation.ipynb")
3435

3536

3637
@pytest.mark.slow

tests/test_networks/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def typical_point_inference_network_subnet():
8585
"spline_coupling_flow",
8686
"flow_matching",
8787
"free_form_flow",
88+
"consistency_model",
8889
],
8990
scope="function",
9091
)
@@ -106,7 +107,8 @@ def inference_network_subnet(request):
106107

107108

108109
@pytest.fixture(
109-
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow"], scope="function"
110+
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"],
111+
scope="function",
110112
)
111113
def generative_inference_network(request):
112114
return request.getfixturevalue(request.param)

0 commit comments

Comments
 (0)