From 142ff2f7510e2572823dc7d41fc0699aacca8f28 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 22 Apr 2025 10:53:55 +0000 Subject: [PATCH 01/46] [no ci] Add advice regarding moving from v1 to v2 to README. Raise awareness regarding missing features and incompatibility between the versions. Similar changes can be made for the docs. --- README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/README.md b/README.md index 3f3269754..64bccaff0 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,33 @@ It provides users and researchers with: BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference fueled by continuous progress in generative AI and Bayesian inference. +## Migrating from BayesFlow 1.x to BayesFlow 2.0+ + +You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library. +While it shares the same overall goals with the 1.x versions, the API is not compatible. +You can find the most recent version of BayesFlow 1.x on the `stable-legacy` branch. +The latest [BayesFlow 1.x documentation](https://bayesflow.org/stable-legacy/index.html) can be accessed by selecting the "stable-legacy" entry in the version picker of the documentation. + +> [!CAUTION] +> You should not upgrade (yet) if one of the following applies: +> +> - You have an ongoing project that uses BayesFlow 1.x, and you do not want to allocate time for migrating to the new API. +> - You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge, this applies to: +> * Two-level/Hierarchical models: `TwoLevelGenerativeModel`, `TwoLevelPrior`. +> * Sensitivity analysis: functionality from the `bayesflow.sensitivity` module. +> * MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options to enable the use of BayesFlow in an MCMC setting. +> * Networks: `EvidentialNetwork`. 
+> * Model misspecification detection: MMD test in the summary space (see #384). +> - You have already trained models in BayesFlow 1.x, that you do not want to re-train with the new version. Loading models from version 1.x in version 2.0+ is not supported. +> +> If you encounter any functionality that is missing and not listed here, please let us know by opening an issue. + +The new version brings many features, like multi-backend support via Keras3, and improved modularity and extensibility. +We recommend to upgrade if none of the above conditions apply. +Continue reading below for installation instructions and examples to get started. +The [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb) guide highlights how concepts and classes relate between the two versions. +For additional information, please refer to the [FAQ](#faq) below. + ## Conceptual Overview
From 9a2f338fcd012fddd33a6e1b923d8b2dd451f496 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 22 Apr 2025 11:20:22 +0000 Subject: [PATCH 02/46] [no ci] Add link to migration advice to docs, add info about v1 --- docsrc/source/index.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docsrc/source/index.md b/docsrc/source/index.md index 36387c8fd..1dd1d5869 100644 --- a/docsrc/source/index.md +++ b/docsrc/source/index.md @@ -10,6 +10,9 @@ It provides users and researchers with: BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference fueled by continuous progress in generative AI and Bayesian inference. +To access the documentation for [BayesFlow version 1.x](https://github.com/bayesflow-org/bayesflow/tree/stable-legacy), select `stable-legacy` in the version picker above. +For advice on the migration from version 1.x to version 2+, please refer to the [README](https://github.com/bayesflow-org/bayesflow/blob/main/README.md). + ## Conceptual Overview
From 688f22ca31d404c09a84976ed754073c9a960fba Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Wed, 23 Apr 2025 11:48:01 +0000 Subject: [PATCH 03/46] [no ci] notebook tests: increase timeout, fix platform/backend dependent code Torch is very slow, so I had to increase the timeout accordingly. --- examples/From_ABC_to_BayesFlow.ipynb | 9 +++++++-- examples/SIR_Posterior_Estimation.ipynb | 6 +++++- examples/Two_Moons_Starter.ipynb | 6 +++++- tests/utils/jupyter.py | 4 ++-- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/From_ABC_to_BayesFlow.ipynb b/examples/From_ABC_to_BayesFlow.ipynb index 334447555..b9757d9c4 100644 --- a/examples/From_ABC_to_BayesFlow.ipynb +++ b/examples/From_ABC_to_BayesFlow.ipynb @@ -38,7 +38,10 @@ "outputs": [], "source": [ "import numpy as np\n", - "import matplotlib.pyplot as plt" + "import matplotlib.pyplot as plt\n", + "import tempfile\n", + "from pathlib import Path\n", + "import platform" ] }, { @@ -322,7 +325,9 @@ ")\n", "\n", "# generate a temporary SQLite DB\n", - "abc_id = abc.new(\"sqlite:////tmp/mjp.db\", observations)" + "prefix = \"sqlite:///\" if platform.system() == \"Windows\" else \"sqlite:////\"\n", + "db_path = (Path(tempfile.gettempdir()).absolute() / \"mjp.db\").as_uri().replace(\"file:///\", prefix)\n", + "abc_id = abc.new(db_path, observations)" ] }, { diff --git a/examples/SIR_Posterior_Estimation.ipynb b/examples/SIR_Posterior_Estimation.ipynb index cadc597aa..c7dafa37f 100644 --- a/examples/SIR_Posterior_Estimation.ipynb +++ b/examples/SIR_Posterior_Estimation.ipynb @@ -19,7 +19,11 @@ "source": [ "import os\n", "# Set to your favorite backend\n", - "os.environ[\"KERAS_BACKEND\"] = \"jax\"" + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", + "else:\n", + " print(f\"Using '{os.environ['KERAS_BACKEND']}' backend\")" ] }, { diff --git a/examples/Two_Moons_Starter.ipynb 
b/examples/Two_Moons_Starter.ipynb index 8fbb1d179..0d87f99c2 100644 --- a/examples/Two_Moons_Starter.ipynb +++ b/examples/Two_Moons_Starter.ipynb @@ -24,7 +24,11 @@ "source": [ "import os\n", "# Set to your favorite backend\n", - "os.environ[\"KERAS_BACKEND\"] = \"jax\"" + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", + "else:\n", + " print(f\"Using '{os.environ['KERAS_BACKEND']}' backend\")" ] }, { diff --git a/tests/utils/jupyter.py b/tests/utils/jupyter.py index f905e1a0c..9a3b8d699 100644 --- a/tests/utils/jupyter.py +++ b/tests/utils/jupyter.py @@ -10,10 +10,10 @@ def run_notebook(path): checkpoint_path = path.parent / "checkpoints" # only clean up if the directory did not exist before the test cleanup_checkpoints = not checkpoint_path.exists() - with open(str(path)) as f: + with open(str(path), encoding="utf-8") as f: nb = nbformat.read(f, nbformat.NO_CONVERT) - kernel = ExecutePreprocessor(timeout=600, kernel_name="python3", resources={"metadata": {"path": path.parent}}) + kernel = ExecutePreprocessor(timeout=3600, kernel_name="python3", resources={"metadata": {"path": path.parent}}) try: result = kernel.preprocess(nb) From de300092875ff824283fef23fa6df8070fcd8940 Mon Sep 17 00:00:00 2001 From: Valentin Pratz <112951103+vpratz@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:39:02 +0200 Subject: [PATCH 04/46] Enable use of summary networks with functional API again (#434) * summary networks: add tests for using functional API * fix build functions for use with functional API --- bayesflow/links/ordered.py | 2 ++ bayesflow/networks/summary_network.py | 1 + bayesflow/networks/transformers/mab.py | 3 +++ bayesflow/networks/transformers/pma.py | 2 ++ bayesflow/networks/transformers/sab.py | 3 +++ bayesflow/utils/decorators.py | 7 +++++-- tests/test_networks/test_summary_networks.py | 22 ++++++++++++++++++++ 7 files changed, 38 insertions(+), 2 
deletions(-) diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py index 47be02317..77545b6f8 100644 --- a/bayesflow/links/ordered.py +++ b/bayesflow/links/ordered.py @@ -2,6 +2,7 @@ from keras.saving import register_keras_serializable as serializable from bayesflow.utils import layer_kwargs +from bayesflow.utils.decorators import sanitize_input_shape @serializable(package="links.ordered") @@ -49,5 +50,6 @@ def call(self, inputs): x = keras.ops.concatenate([below, anchor_input, above], self.axis) return x + @sanitize_input_shape def compute_output_shape(self, input_shape): return input_shape diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 316df39e6..6e97c618f 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -21,6 +21,7 @@ def build(self, input_shape): if self.base_distribution is not None: self.base_distribution.build(keras.ops.shape(z)) + @sanitize_input_shape def compute_output_shape(self, input_shape): return keras.ops.shape(self.call(keras.ops.zeros(input_shape))) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index a2e22da16..8f0e3f881 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -4,6 +4,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs +from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -122,8 +123,10 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - return out # noinspection PyMethodOverriding + @sanitize_input_shape def build(self, seq_x_shape, seq_y_shape): self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape)) + @sanitize_input_shape def compute_output_shape(self, seq_x_shape, seq_y_shape): return keras.ops.shape(self.call(keras.ops.zeros(seq_x_shape), 
keras.ops.zeros(seq_y_shape))) diff --git a/bayesflow/networks/transformers/pma.py b/bayesflow/networks/transformers/pma.py index 5eb6a269d..956c85b48 100644 --- a/bayesflow/networks/transformers/pma.py +++ b/bayesflow/networks/transformers/pma.py @@ -4,6 +4,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs +from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable from .mab import MultiHeadAttentionBlock @@ -125,5 +126,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: summaries = self.mab(seed_tiled, set_x_transformed, training=training, **kwargs) return ops.reshape(summaries, (ops.shape(summaries)[0], -1)) + @sanitize_input_shape def compute_output_shape(self, input_shape): return keras.ops.shape(self.call(keras.ops.zeros(input_shape))) diff --git a/bayesflow/networks/transformers/sab.py b/bayesflow/networks/transformers/sab.py index a69dc5fa4..a447d92a2 100644 --- a/bayesflow/networks/transformers/sab.py +++ b/bayesflow/networks/transformers/sab.py @@ -1,6 +1,7 @@ import keras from bayesflow.types import Tensor +from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable from .mab import MultiHeadAttentionBlock @@ -16,6 +17,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock): """ # noinspection PyMethodOverriding + @sanitize_input_shape def build(self, input_set_shape): self.call(keras.ops.zeros(input_set_shape)) @@ -42,5 +44,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: return super().call(input_set, input_set, training=training, **kwargs) # noinspection PyMethodOverriding + @sanitize_input_shape def compute_output_shape(self, input_set_shape): return keras.ops.shape(self.call(keras.ops.zeros(input_set_shape))) diff --git a/bayesflow/utils/decorators.py b/bayesflow/utils/decorators.py index 91afc9fb7..7fd32edc9 
100644 --- a/bayesflow/utils/decorators.py +++ b/bayesflow/utils/decorators.py @@ -114,7 +114,7 @@ def callback(x): def sanitize_input_shape(fn: Callable): - """Decorator to replace the first dimension in input_shape with a dummy batch size if it is None""" + """Decorator to replace the first dimension in ..._shape arguments with a dummy batch size if it is None""" # The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which # causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used @@ -126,5 +126,8 @@ def callback(input_shape: Shape) -> Shape: return tuple(input_shape) return input_shape - fn = argument_callback("input_shape", callback)(fn) + args = inspect.getfullargspec(fn).args + for arg in args: + if arg.endswith("_shape"): + fn = argument_callback(arg, callback)(fn) return fn diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 082ce4d25..50e1726c1 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -25,6 +25,28 @@ def test_build(automatic, summary_network, random_set): assert summary_network.variables, "Model has no variables." +@pytest.mark.parametrize("automatic", [True, False]) +def test_build_functional_api(automatic, summary_network, random_set): + if summary_network is None: + pytest.skip(reason="Nothing to do, because there is no summary network.") + + assert summary_network.built is False + + inputs = keras.layers.Input(shape=keras.ops.shape(random_set)[1:]) + outputs = summary_network(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + if automatic: + model(random_set) + else: + model.build(keras.ops.shape(random_set)) + + assert model.built is True + + # check the model has variables + assert summary_network.variables, "Model has no variables." 
+ + def test_variable_batch_size(summary_network, random_set): if summary_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") From 0eefa695751511e0e7690e56942bfd53e41d1efa Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 25 Apr 2025 08:25:36 +0000 Subject: [PATCH 05/46] [no ci] docs: add GitHub and Discourse links, reorder navbar --- docsrc/source/conf.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/docsrc/source/conf.py b/docsrc/source/conf.py index cfbc931b9..4a21b3711 100644 --- a/docsrc/source/conf.py +++ b/docsrc/source/conf.py @@ -141,7 +141,29 @@ "image_light": "_static/bayesflow_hor.png", "image_dark": "_static/bayesflow_hor_dark.png", }, - "navbar_center": ["version-switcher", "navbar-nav"], + "icon_links_label": "Icon Links", + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/bayesflow-org/bayesflow", + "icon": "fa-brands fa-square-github", + "type": "fontawesome", + }, + { + "name": "Discourse Forum", + "url": "https://discuss.bayesflow.org/", + "icon": "fa-brands fa-discourse", + "type": "fontawesome", + }, + ], + "navbar_align": "left", + # -- Template placement in theme layouts ---------------------------------- + "navbar_start": ["navbar-logo"], + # Note that the alignment of navbar_center is controlled by navbar_align + "navbar_center": ["navbar-nav"], + "navbar_end": ["theme-switcher", "navbar-icon-links", "version-switcher"], + # navbar_persistent is persistent right (even when on mobiles) + "navbar_persistent": ["search-button"], "switcher": { "json_url": "/versions.json", "version_match": current, From a97b5a2fc5f120505e43d0d9c8e55e12a8a1ebf6 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 25 Apr 2025 08:27:47 +0000 Subject: [PATCH 06/46] [no ci] docs: acknowledge scikit-learn website --- docsrc/source/index.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docsrc/source/index.md b/docsrc/source/index.md index 
f377318db..ef0675f78 100644 --- a/docsrc/source/index.md +++ b/docsrc/source/index.md @@ -237,6 +237,8 @@ If you are interested in a curated list of resources, including reviews, softwar This project is currently managed by researchers from Rensselaer Polytechnic Institute, TU Dortmund University, and Heidelberg University. It is partially funded by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation, Project 528702768). The project is further supported by Germany's Excellence Strategy -- EXC-2075 - 390740016 (Stuttgart Cluster of Excellence SimTech) and EXC-2181 - 390900948 (Heidelberg Cluster of Excellence STRUCTURES), as well as the Informatics for Life initiative funded by the Klaus Tschira Foundation. +The [scikit-learn](https://scikit-learn.org/) website was a great resource and inspiration for this site and the API documentation. We thank the scikit-learn community for sharing their configurations, which allowed us to include many nice features into this site as well. + ## License \& Source Code BayesFlow is released under {mainbranch}`MIT License `. 
From 8ac8aa314fb4a74bedf64a02533411052abb1935 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 25 Apr 2025 08:33:37 +0000 Subject: [PATCH 07/46] [no ci] docs: capitalize navigation headings --- docsrc/source/about.rst | 2 +- docsrc/source/index.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docsrc/source/about.rst b/docsrc/source/about.rst index e42c2e0b8..67ca0e102 100644 --- a/docsrc/source/about.rst +++ b/docsrc/source/about.rst @@ -1,4 +1,4 @@ -About us +About Us ======== Core maintainers diff --git a/docsrc/source/index.md b/docsrc/source/index.md index ef0675f78..f89c5ff3f 100644 --- a/docsrc/source/index.md +++ b/docsrc/source/index.md @@ -260,5 +260,5 @@ examples api/bayesflow about Contributing -Developer docs +Developer Docs ``` From e590a4313896864c75d6e6e9712a4cc8022970fc Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 25 Apr 2025 11:15:52 +0200 Subject: [PATCH 08/46] [no ci] README: move details on migration to FAQ --- README.md | 62 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 64bccaff0..12162dfa3 100644 --- a/README.md +++ b/README.md @@ -18,28 +18,14 @@ fueled by continuous progress in generative AI and Bayesian inference. You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library. While it shares the same overall goals with the 1.x versions, the API is not compatible. -You can find the most recent version of BayesFlow 1.x on the `stable-legacy` branch. -The latest [BayesFlow 1.x documentation](https://bayesflow.org/stable-legacy/index.html) can be accessed by selecting the "stable-legacy" entry in the version picker of the documentation. > [!CAUTION] -> You should not upgrade (yet) if one of the following applies: -> -> - You have an ongoing project that uses BayesFlow 1.x, and you do not want to allocate time for migrating to the new API. 
-> - You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge, this applies to: -> * Two-level/Hierarchical models: `TwoLevelGenerativeModel`, `TwoLevelPrior`. -> * Sensitivity analysis: functionality from the `bayesflow.sensitivity` module. -> * MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options to enable the use of BayesFlow in an MCMC setting. -> * Networks: `EvidentialNetwork`. -> * Model misspecification detection: MMD test in the summary space (see #384). -> - You have already trained models in BayesFlow 1.x, that you do not want to re-train with the new version. Loading models from version 1.x in version 2.0+ is not supported. -> -> If you encounter any functionality that is missing and not listed here, please let us know by opening an issue. - -The new version brings many features, like multi-backend support via Keras3, and improved modularity and extensibility. -We recommend to upgrade if none of the above conditions apply. -Continue reading below for installation instructions and examples to get started. -The [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb) guide highlights how concepts and classes relate between the two versions. -For additional information, please refer to the [FAQ](#faq) below. +> A few features, most notably hierarchical models, have not been ported to BayesFlow 2.0+ +> yet. We are working on those features and plan to add them soon. You can find the complete +> list in the [FAQ](#faq) below. + +The [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb) guide +highlights how concepts and classes relate between the two versions. ## Conceptual Overview @@ -242,11 +228,47 @@ while the old version was based on TensorFlow. ------------- +**Question:** +Should I switch to BayesFlow 2.0+ now? Are there features that are still missing? 
+ +**Answer:** +In general, we recommend to switch, as the new version is easier to use and will continue +to receive improvements and new features. However, a few features are still missing, so you +might want to wait until everything you need has been ported to BayesFlow 2.0+. + +Depending on your needs, you might not want to upgrade yet if one of the following applies: + +- You have an ongoing project that uses BayesFlow 1.x, and you do not want to allocate + time for migrating it to the new API. +- You have already trained models in BayesFlow 1.x, that you do not want to re-train + with the new version. Loading models from version 1.x in version 2.0+ is not supported. +- You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge, + this applies to: + * Two-level/Hierarchical models: `TwoLevelGenerativeModel`, `TwoLevelPrior`. + * Sensitivity analysis: functionality from the `bayesflow.sensitivity` module. + * MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options + to enable the use of BayesFlow in an MCMC setting. + * Networks: `EvidentialNetwork`. + * Model misspecification detection: MMD test in the summary space (see #384). + +If you encounter any functionality that is missing and not listed here, please let us +know by opening an issue. + +------------- + **Question:** I still need the old BayesFlow for some of my projects. How can I install it? **Answer:** You can find and install the old Bayesflow version via the `stable-legacy` branch on GitHub. +The corresponding [documentation](https://bayesflow.org/stable-legacy/index.html) can be +accessed by selecting the "stable-legacy" entry in the version picker of the documentation. 
+ +You can also install the latest version of BayesFlow v1.x from PyPI using + +``` +$ pip install "bayesflow<2.0" +``` ------------- From 076ceafe8049f127bc132355e4baf763eb95779f Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 25 Apr 2025 09:19:25 +0000 Subject: [PATCH 09/46] [no ci] Change heading on migration section --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 12162dfa3..68ea950da 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ It provides users and researchers with: BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference fueled by continuous progress in generative AI and Bayesian inference. -## Migrating from BayesFlow 1.x to BayesFlow 2.0+ +## Important Note for Existing Users You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library. While it shares the same overall goals with the 1.x versions, the API is not compatible. 
From 53dda82c7849b1ccfa21431ad1d440f9b7d01b39 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 25 Apr 2025 09:20:58 +0000 Subject: [PATCH 10/46] [no ci] docs: remove $ prefix from v1 install command --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 68ea950da..a147dbf4d 100644 --- a/README.md +++ b/README.md @@ -267,7 +267,7 @@ accessed by selecting the "stable-legacy" entry in the version picker of the doc You can also install the latest version of BayesFlow v1.x from PyPI using ``` -$ pip install "bayesflow<2.0" +pip install "bayesflow<2.0" ``` ------------- From 7ea287f08f452a0a0202c60e1fecd7c4ae505c5c Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 25 Apr 2025 08:36:13 -0400 Subject: [PATCH 11/46] More tests (#437) * fix docs of coupling flow * add additional tests --- .../networks/coupling_flow/coupling_flow.py | 2 +- .../test_coupling_flow/test_permutations.py | 117 ++++++++++++++++++ tests/test_networks/test_embeddings.py | 85 +++++++++++++ 3 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 tests/test_networks/test_coupling_flow/test_permutations.py create mode 100644 tests/test_networks/test_embeddings.py diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index c7b528987..ee78f180e 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -77,7 +77,7 @@ def __init__( The type of transformation used in the coupling layers, such as "affine". Default is "affine". permutation : str or None, optional - The type of permutation applied between layers. Can be "random" or None + The type of permutation applied between layers. Can be "orthogonal", "random", "swap", or None (no permutation). Default is "random". use_actnorm : bool, optional Whether to apply ActNorm before each coupling layer. Default is True. 
diff --git a/tests/test_networks/test_coupling_flow/test_permutations.py b/tests/test_networks/test_coupling_flow/test_permutations.py new file mode 100644 index 000000000..63c50ae2c --- /dev/null +++ b/tests/test_networks/test_coupling_flow/test_permutations.py @@ -0,0 +1,117 @@ +import pytest +import keras +import numpy as np + +from bayesflow.networks.coupling_flow.permutations import ( + FixedPermutation, + OrthogonalPermutation, + RandomPermutation, + Swap, +) + + +@pytest.fixture(params=[FixedPermutation, OrthogonalPermutation, RandomPermutation, Swap]) +def permutation_class(request): + return request.param + + +@pytest.fixture +def input_tensor(): + return keras.random.normal((2, 5)) + + +def test_fixed_permutation_build_and_call(): + # Since FixedPermutation is abstract, create a subclass for testing build. + class TestPerm(FixedPermutation): + def build(self, xz_shape, **kwargs): + length = xz_shape[-1] + self.forward_indices = keras.ops.arange(length - 1, -1, -1) + self.inverse_indices = keras.ops.arange(length - 1, -1, -1) + + layer = TestPerm() + input_shape = (2, 4) + layer.build(input_shape) + + x = keras.ops.convert_to_tensor(np.arange(8).reshape(input_shape).astype("float32")) + z, log_det = layer(x, inverse=False) + x_inv, log_det_inv = layer(z, inverse=True) + + # Check shape preservation + assert z.shape == x.shape + assert x_inv.shape == x.shape + # Forward then inverse recovers input + np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(x), atol=1e-5) + # log_det values should be zero tensors with the correct shape + assert tuple(log_det.shape) == input_shape[:-1] + assert tuple(log_det_inv.shape) == input_shape[:-1] + + +def test_orthogonal_permutation_build_and_call(input_tensor): + layer = OrthogonalPermutation() + input_shape = keras.ops.shape(input_tensor) + layer.build(input_shape) + + z, log_det = layer(input_tensor) + x_inv, log_det_inv = layer(z, inverse=True) + + # Check output shapes + assert 
z.shape == input_tensor.shape + assert x_inv.shape == input_tensor.shape + + # Forward + inverse should approximately recover input (allow some numeric tolerance) + np.testing.assert_allclose( + keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), rtol=1e-5, atol=1e-5 + ) + + # log_det should be scalar or batched scalar + if len(log_det.shape) > 0: + assert log_det.shape[0] == input_tensor.shape[0] # batch dim + else: + assert log_det.shape == () + + # log_det_inv should be negative of log_det (det(inv) = 1/det) + log_det_np = keras.ops.convert_to_numpy(log_det) + log_det_inv_np = keras.ops.convert_to_numpy(log_det_inv) + np.testing.assert_allclose(log_det_inv_np, -log_det_np, rtol=1e-5, atol=1e-5) + + +def test_random_permutation_build_and_call(input_tensor): + layer = RandomPermutation() + input_shape = keras.ops.shape(input_tensor) + layer.build(input_shape) + + # Assert forward_indices and inverse_indices are set and consistent + fwd = keras.ops.convert_to_numpy(layer.forward_indices) + inv = keras.ops.convert_to_numpy(layer.inverse_indices) + # Applying inv on fwd must yield ordered indices + reordered = fwd[inv] + np.testing.assert_array_equal(np.arange(len(fwd)), reordered) + + z, log_det = layer(input_tensor) + x_inv, log_det_inv = layer(z, inverse=True) + + assert z.shape == input_tensor.shape + assert x_inv.shape == input_tensor.shape + np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), atol=1e-5) + assert tuple(log_det.shape) == input_shape[:-1] + assert tuple(log_det_inv.shape) == input_shape[:-1] + + +def test_swap_build_and_call(input_tensor): + layer = Swap() + input_shape = keras.ops.shape(input_tensor) + layer.build(input_shape) + + fwd = keras.ops.convert_to_numpy(layer.forward_indices) + inv = keras.ops.convert_to_numpy(layer.inverse_indices) + reordered = fwd[inv] + np.testing.assert_array_equal(np.arange(len(fwd)), reordered) + + z, log_det = layer(input_tensor) + 
x_inv, log_det_inv = layer(z, inverse=True) + + assert z.shape == input_tensor.shape + assert x_inv.shape == input_tensor.shape + np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), atol=1e-5) + assert tuple(log_det.shape) == input_shape[:-1] + assert tuple(log_det_inv.shape) == input_shape[:-1] diff --git a/tests/test_networks/test_embeddings.py b/tests/test_networks/test_embeddings.py new file mode 100644 index 000000000..7385d94c0 --- /dev/null +++ b/tests/test_networks/test_embeddings.py @@ -0,0 +1,85 @@ +import pytest +import keras + +from bayesflow.networks.embeddings import ( + FourierEmbedding, + RecurrentEmbedding, + Time2Vec, +) + + +def test_fourier_embedding_output_shape_and_type(): + embed_dim = 8 + batch_size = 4 + + emb_layer = FourierEmbedding(embed_dim=embed_dim, include_identity=True) + # use keras.ops.zeros with shape (batch_size, 1) and float32 dtype + t = keras.ops.zeros((batch_size, 1), dtype="float32") + + emb = emb_layer(t) + # Expected shape is (batch_size, embed_dim + 1) if include_identity else (batch_size, embed_dim) + expected_dim = embed_dim + 1 + assert emb.shape[0] == batch_size + assert emb.shape[1] == expected_dim + # Check type - it should be a Keras tensor, convert to numpy for checking + np_emb = keras.ops.convert_to_numpy(emb) + assert np_emb.shape == (batch_size, expected_dim) + + +def test_fourier_embedding_without_identity(): + embed_dim = 8 + batch_size = 3 + + emb_layer = FourierEmbedding(embed_dim=embed_dim, include_identity=False) + t = keras.ops.zeros((batch_size, 1), dtype="float32") + + emb = emb_layer(t) + expected_dim = embed_dim + assert emb.shape[0] == batch_size + assert emb.shape[1] == expected_dim + + +def test_fourier_embedding_raises_for_odd_embed_dim(): + with pytest.raises(ValueError): + FourierEmbedding(embed_dim=7) + + +def test_recurrent_embedding_lstm_and_gru_shapes(): + batch_size = 2 + seq_len = 5 + dim = 3 + embed_dim = 6 + + # Dummy input + x = 
keras.ops.zeros((batch_size, seq_len, dim), dtype="float32") + + # lstm + lstm_layer = RecurrentEmbedding(embed_dim=embed_dim, embedding="lstm") + emb_lstm = lstm_layer(x) + # Check the concatenated shape: last dimension = original dim + embed_dim + assert emb_lstm.shape == (batch_size, seq_len, dim + embed_dim) + + # gru + gru_layer = RecurrentEmbedding(embed_dim=embed_dim, embedding="gru") + emb_gru = gru_layer(x) + assert emb_gru.shape == (batch_size, seq_len, dim + embed_dim) + + +def test_recurrent_embedding_raises_unknown_embedding(): + with pytest.raises(ValueError): + RecurrentEmbedding(embed_dim=4, embedding="unknown") + + +def test_time2vec_shapes_and_output(): + batch_size = 3 + seq_len = 7 + dim = 2 + num_periodic_features = 4 + + x = keras.ops.zeros((batch_size, seq_len, dim), dtype="float32") + time2vec_layer = Time2Vec(num_periodic_features=num_periodic_features) + + emb = time2vec_layer(x) + # The last dimension should be dim + num_periodic_features + 1 (trend + periodic) + expected_dim = dim + num_periodic_features + 1 + assert emb.shape == (batch_size, seq_len, expected_dim) From 42fa0358b1ceca48c3f6fdf868c6908ed2d77f0b Mon Sep 17 00:00:00 2001 From: Valentin Pratz <112951103+vpratz@users.noreply.github.com> Date: Fri, 25 Apr 2025 19:37:03 +0200 Subject: [PATCH 12/46] Automatically run slow tests when main is involved. (#438) In addition, this PR limits the slow test to Windows and Python 3.10. The choices are somewhat arbitrary, my thought was to test the setup not covered as much through use by the devs. 
--- .github/workflows/tests.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 41a254d1f..f90389f69 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -73,8 +73,11 @@ jobs: pytest -x -m "not slow" - name: Run Slow Tests - # run all slow tests only on manual trigger - if: github.event_name == 'workflow_dispatch' + # Run slow tests on manual trigger and pushes/PRs to main. + # Limit to one OS and Python version to save compute. + # Multiline if statements are weird, https://github.com/orgs/community/discussions/25641, + # but feel free to convert it. + if: ${{ ((github.event_name == 'workflow_dispatch') || (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'pull_request' && github.base_ref == 'main')) && ((matrix.os == 'windows-latest') && (matrix.python-version == '3.10')) }} run: | pytest -m "slow" From 86f2f5b31cb2f62847b0f5322bc4730ef8344c01 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 14:18:12 -0400 Subject: [PATCH 13/46] reintroduce symbolic tensor check in log_sinkhorn --- bayesflow/utils/optimal_transport/log_sinkhorn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 2a65d039e..3538eaeff 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -1,6 +1,7 @@ import keras from .. 
import logging +from ..tensor_utils import is_symbolic_tensor from .euclidean import euclidean @@ -26,6 +27,9 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16) + if is_symbolic_tensor(log_plan): + return log_plan + def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) From 206d70617057ae7763a25fec380c785c82b6333c Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 14:58:31 -0400 Subject: [PATCH 14/46] Update dispatch --- bayesflow/distributions/diagonal_normal.py | 47 ++++++++-------- bayesflow/distributions/diagonal_student_t.py | 54 +++++++++---------- bayesflow/distributions/mixture.py | 28 ++++++---- bayesflow/utils/dispatch/__init__.py | 3 ++ .../dispatch}/find_distribution.py | 17 ++++-- .../utils/dispatch/find_inference_network.py | 39 ++++++++++++++ .../utils/dispatch/find_summary_network.py | 49 +++++++++++++++++ bayesflow/utils/workflow_utils.py | 44 --------------- 8 files changed, 170 insertions(+), 111 deletions(-) rename bayesflow/{distributions => utils/dispatch}/find_distribution.py (60%) create mode 100644 bayesflow/utils/dispatch/find_inference_network.py create mode 100644 bayesflow/utils/dispatch/find_summary_network.py delete mode 100644 bayesflow/utils/workflow_utils.py diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 0d439f704..98a127b1c 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -3,6 +3,7 @@ import numpy as np import keras +from keras import ops from bayesflow.types import Shape, Tensor from bayesflow.utils.decorators import allow_batch_size @@ -19,7 +20,7 @@ def __init__( self, mean: int | float | np.ndarray | Tensor = 0.0, std: int | float | np.ndarray | Tensor = 1.0, - use_learnable_parameters: bool = False, + trainable_parameters: bool = False, seed_generator: keras.random.SeedGenerator 
= None, **kwargs, ): @@ -39,7 +40,7 @@ def __init__( std : int, float, np.ndarray, or Tensor, optional The standard deviation of the Gaussian distribution. Can be a scalar or a tensor. Default is 1.0. - use_learnable_parameters : bool, optional + trainable_parameters : bool, optional Whether to treat the mean and standard deviation as learnable parameters. Default is False. seed_generator : keras.random.SeedGenerator, optional A Keras seed generator for reproducible random sampling. If None, a new seed @@ -53,47 +54,41 @@ def __init__( self.mean = mean self.std = std + self.trainable_parameters = trainable_parameters + self.seed_generator = seed_generator or keras.random.SeedGenerator() + self.dim = None self.log_normalization_constant = None - - self.use_learnable_parameters = use_learnable_parameters - - if seed_generator is None: - seed_generator = keras.random.SeedGenerator() - - self.seed_generator = seed_generator + self._mean = None + self._std = None def build(self, input_shape: Shape) -> None: + if self.built: + return + self.dim = int(input_shape[-1]) - self.mean = keras.ops.broadcast_to(self.mean, (self.dim,)) - self.mean = keras.ops.cast(self.mean, "float32") - self.std = keras.ops.broadcast_to(self.std, (self.dim,)) - self.std = keras.ops.cast(self.std, "float32") + self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32") + self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32") - self.log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - keras.ops.sum( - keras.ops.log(self.std) - ) + self.log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self.std)) - if self.use_learnable_parameters: + if self.trainable_parameters: self._mean = self.add_weight( - shape=keras.ops.shape(self.mean), - # Initializing with const tensor https://github.com/keras-team/keras/pull/20457#discussion_r1832081248 - initializer=keras.initializers.get(value=self.mean), + 
shape=ops.shape(self.mean), + initializer=keras.initializers.get(self.mean), dtype="float32", + trainable=True, ) self._std = self.add_weight( - shape=keras.ops.shape(self.std), - # Initializing with const tensor https://github.com/keras-team/keras/pull/20457#discussion_r1832081248 - initializer=keras.initializers.get(self.std), - dtype="float32", + shape=ops.shape(self.std), initializer=keras.initializers.get(self.std), dtype="float32", trainable=True ) else: self._mean = self.mean self._std = self.std def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: - result = -0.5 * keras.ops.sum((samples - self._mean) ** 2 / self.std**2, axis=-1) + result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: result += self.log_normalization_constant @@ -110,7 +105,7 @@ def get_config(self): config = { "mean": self.mean, "std": self.std, - "use_learnable_parameters": self.use_learnable_parameters, + "trainable_parameters": self.trainable_parameters, "seed_generator": self.seed_generator, } diff --git a/bayesflow/distributions/diagonal_student_t.py b/bayesflow/distributions/diagonal_student_t.py index 977fa057b..cd32a67fb 100644 --- a/bayesflow/distributions/diagonal_student_t.py +++ b/bayesflow/distributions/diagonal_student_t.py @@ -1,8 +1,10 @@ -import keras - import math + import numpy as np +import keras +from keras import ops + from bayesflow.types import Shape, Tensor from bayesflow.utils import expand_tile from bayesflow.utils.decorators import allow_batch_size @@ -20,7 +22,7 @@ def __init__( df: int | float, loc: int | float | np.ndarray | Tensor = 0.0, scale: int | float | np.ndarray | Tensor = 1.0, - use_learnable_parameters: bool = False, + trainable_parameters: bool = False, seed_generator: keras.random.SeedGenerator = None, **kwargs, ): @@ -42,8 +44,8 @@ def __init__( The location parameter (mean) of the distribution. Default is 0.0. 
scale : int, float, np.ndarray, or Tensor, optional The scale parameter (standard deviation) of the distribution. Default is 1.0. - use_learnable_parameters : bool, optional - Whether to treat `loc` and `scale` as learnable parameters. Default is False. + trainable_parameters : bool, optional + Whether to treat `loc` and `scale` as trainable parameters. Default is False. seed_generator : keras.random.SeedGenerator, optional A Keras seed generator for reproducible random sampling. If None, a new seed generator is created. Default is None. @@ -57,52 +59,50 @@ def __init__( self.loc = loc self.scale = scale - self.dim = None - self.log_normalization_constant = None + self.trainable_parameters = trainable_parameters - self.use_learnable_parameters = use_learnable_parameters + self.seed_generator = seed_generator or keras.random.SeedGenerator() - if seed_generator is None: - seed_generator = keras.random.SeedGenerator() - - self.seed_generator = seed_generator + self.log_normalization_constant = None + self.dim = None + self._loc = None + self._scale = None def build(self, input_shape: Shape) -> None: + if self.built: + return + self.dim = int(input_shape[-1]) # convert to tensor and broadcast if necessary - self.loc = keras.ops.broadcast_to(self.loc, (self.dim,)) - self.loc = keras.ops.cast(self.loc, "float32") - - self.scale = keras.ops.broadcast_to(self.scale, (self.dim,)) - self.scale = keras.ops.cast(self.scale, "float32") + self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32") + self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32") self.log_normalization_constant = ( -0.5 * self.dim * math.log(self.df) - 0.5 * self.dim * math.log(math.pi) - math.lgamma(0.5 * self.df) + math.lgamma(0.5 * (self.df + self.dim)) - - keras.ops.sum(keras.ops.log(self.scale)) + - ops.sum(keras.ops.log(self.scale)) ) - if self.use_learnable_parameters: + if self.trainable_parameters: self._loc = self.add_weight( - shape=keras.ops.shape(self.loc), - 
initializer=keras.initializers.get(self.loc), - dtype="float32", + shape=ops.shape(self.loc), initializer=keras.initializers.get(self.loc), dtype="float32", trainable=True ) self._scale = self.add_weight( - shape=keras.ops.shape(self.scale), + shape=ops.shape(self.scale), initializer=keras.initializers.get(self.scale), dtype="float32", + trainable=True, ) else: self._loc = self.loc self._scale = self.scale def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: - mahalanobis_term = keras.ops.sum((samples - self._loc) ** 2 / self._scale**2, axis=-1) - result = -0.5 * (self.df + self.dim) * keras.ops.log1p(mahalanobis_term / self.df) + mahalanobis_term = ops.sum((samples - self._loc) ** 2 / self._scale**2, axis=-1) + result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df) if normalize: result += self.log_normalization_constant @@ -122,7 +122,7 @@ def sample(self, batch_shape: Shape) -> Tensor: normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator) - return self._loc + self._scale * normal_samples * keras.ops.sqrt(self.df / chi2_samples) + return self._loc + self._scale * normal_samples * ops.sqrt(self.df / chi2_samples) def get_config(self): base_config = super().get_config() @@ -131,7 +131,7 @@ def get_config(self): "df": self.df, "loc": self.loc, "scale": self.scale, - "use_learnable_parameters": self.use_learnable_parameters, + "trainable_parameters": self.trainable_parameters, "seed_generator": self.seed_generator, } diff --git a/bayesflow/distributions/mixture.py b/bayesflow/distributions/mixture.py index 0946f72b7..d7f6bd758 100644 --- a/bayesflow/distributions/mixture.py +++ b/bayesflow/distributions/mixture.py @@ -50,22 +50,18 @@ def __init__( super().__init__(**kwargs) - self.dim = None self.distributions = distributions if mixture_logits is None: - mixture_logits = keras.ops.ones(shape=len(distributions)) - - self.mixture_logits = mixture_logits - self._mixture_logits = 
self.add_weight( - shape=(len(distributions),), - initializer=keras.initializers.Constant(value=mixture_logits), - dtype="float32", - trainable=trainable_mixture, - ) + self.mixture_logits = ops.ones(shape=len(distributions)) + else: + self.mixture_logits = ops.convert_to_tensor(mixture_logits) self.trainable_mixture = trainable_mixture + self.dim = None + self._mixture_logits = None + @allow_batch_size def sample(self, batch_shape: Shape) -> Tensor: """ @@ -138,10 +134,20 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: return log_prob def build(self, input_shape: Shape) -> None: + if self.built: + return + + self.dim = input_shape[-1] + for distribution in self.distributions: distribution.build(input_shape) - self.dim = input_shape[-1] + self._mixture_logits = self.add_weight( + shape=(len(self.distributions),), + initializer=keras.initializers.get(self.mixture_logits), + dtype="float32", + trainable=self.trainable_mixture, + ) def get_config(self): base_config = super().get_config() diff --git a/bayesflow/utils/dispatch/__init__.py b/bayesflow/utils/dispatch/__init__.py index 422f014e0..852756780 100644 --- a/bayesflow/utils/dispatch/__init__.py +++ b/bayesflow/utils/dispatch/__init__.py @@ -3,3 +3,6 @@ from .find_permutation import find_permutation from .find_pooling import find_pooling from .find_recurrent_net import find_recurrent_net +from .find_inference_network import find_inference_network +from .find_summary_network import find_summary_network +from .find_distribution import find_distribution diff --git a/bayesflow/distributions/find_distribution.py b/bayesflow/utils/dispatch/find_distribution.py similarity index 60% rename from bayesflow/distributions/find_distribution.py rename to bayesflow/utils/dispatch/find_distribution.py index 84ef56c15..f94a9f262 100644 --- a/bayesflow/distributions/find_distribution.py +++ b/bayesflow/utils/dispatch/find_distribution.py @@ -1,6 +1,5 @@ from functools import singledispatch - -from 
bayesflow.distributions import Distribution +import keras @singledispatch @@ -15,8 +14,20 @@ def _(name: str, *args, **kwargs): from bayesflow.distributions import DiagonalNormal distribution = DiagonalNormal(*args, **kwargs) + + case "student" | "student-t" | "student_t": + from bayesflow.distributions import DiagonalStudentT + + distribution = DiagonalStudentT(*args, **kwargs) + + case "mixture": + raise ValueError( + "Mixture distributions need to be explicitly defined as bf.distributions.Mixture(...) " + "and passed to the constructor." + ) case "none": distribution = None + case other: raise ValueError(f"Unsupported distribution name '{other}'.") @@ -29,5 +40,5 @@ def _(none: None, *args, **kwargs): @find_distribution.register -def _(distribution: Distribution, *args, **kwargs): +def _(distribution: keras.Layer, *args, **kwargs): return distribution diff --git a/bayesflow/utils/dispatch/find_inference_network.py b/bayesflow/utils/dispatch/find_inference_network.py new file mode 100644 index 000000000..617018de3 --- /dev/null +++ b/bayesflow/utils/dispatch/find_inference_network.py @@ -0,0 +1,39 @@ +from functools import singledispatch +import keras + + +@singledispatch +def find_inference_network(arg, *args, **kwargs): + raise TypeError(f"Cannot infer inference network from {arg!r}.") + + +@find_inference_network.register +def _(name: str, *args, **kwargs): + match name.lower(): + case "coupling_flow": + from bayesflow.networks import CouplingFlow + + return CouplingFlow(*args, **kwargs) + + case "flow_matching": + from bayesflow.networks import FlowMatching + + return FlowMatching(*args, **kwargs) + + case "consistency_model": + from bayesflow.networks import ConsistencyModel + + return ConsistencyModel(*args, **kwargs) + + case unknown_network: + raise ValueError(f"Unknown inference network: '{unknown_network}'") + + +@find_inference_network.register +def _(layer: keras.Layer, *args, **kwargs): + return layer + + +@find_inference_network.register +def 
_(model: keras.Model, *args, **kwargs): + return model diff --git a/bayesflow/utils/dispatch/find_summary_network.py b/bayesflow/utils/dispatch/find_summary_network.py new file mode 100644 index 000000000..bc14b7e21 --- /dev/null +++ b/bayesflow/utils/dispatch/find_summary_network.py @@ -0,0 +1,49 @@ +from functools import singledispatch +import keras + + +@singledispatch +def find_summary_network(arg, *args, **kwargs): + raise TypeError(f"Cannot infer inference network from {arg!r}.") + + +@find_summary_network.register +def _(name: str, *args, **kwargs): + match name.lower(): + case "deep_set": + from bayesflow.networks import DeepSet + + return DeepSet(*args, **kwargs) + + case "set_transformer": + from bayesflow.networks import SetTransformer + + return SetTransformer(*args, **kwargs) + + case "fusion_transformer": + from bayesflow.networks import FusionTransformer + + return FusionTransformer(*args, **kwargs) + + case "time_series_transformer": + from bayesflow.networks import TimeSeriesTransformer + + return TimeSeriesTransformer(*args, **kwargs) + + case "time_series_network": + from bayesflow.networks import TimeSeriesNetwork + + return TimeSeriesNetwork(*args, **kwargs) + + case unknown_network: + raise ValueError(f"Unknown summary network: '{unknown_network}'") + + +@find_summary_network.register +def _(layer: keras.Layer, *args, **kwargs): + return layer + + +@find_summary_network.register +def _(model: keras.Model, *args, **kwargs): + return model diff --git a/bayesflow/utils/workflow_utils.py b/bayesflow/utils/workflow_utils.py deleted file mode 100644 index 0f23a8cb8..000000000 --- a/bayesflow/utils/workflow_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -import bayesflow.networks -from bayesflow.networks import InferenceNetwork, PointInferenceNetwork, SummaryNetwork - - -def find_inference_network(inference_network: InferenceNetwork | str, **kwargs) -> InferenceNetwork: - if isinstance(inference_network, InferenceNetwork) or isinstance(inference_network, 
PointInferenceNetwork): - return inference_network - if isinstance(inference_network, type): - return inference_network(**kwargs) - - match inference_network.lower(): - case "coupling_flow": - return bayesflow.networks.CouplingFlow(**kwargs) - case "flow_matching": - return bayesflow.networks.FlowMatching(**kwargs) - case "consistency_model": - return bayesflow.networks.ConsistencyModel(**kwargs) - case str() as unknown_network: - raise ValueError(f"Unknown inference network: '{unknown_network}'") - case other: - raise TypeError(f"Unknown transform type: {other}") - - -def find_summary_network(summary_network: SummaryNetwork | str, **kwargs) -> SummaryNetwork: - if isinstance(summary_network, SummaryNetwork): - return summary_network - if isinstance(summary_network, type): - return summary_network(**kwargs) - - match summary_network.lower(): - case "deep_set": - return bayesflow.networks.DeepSet(**kwargs) - case "set_transformer": - return bayesflow.networks.SetTransformer(**kwargs) - case "fusion_transformer": - return bayesflow.networks.FusionTransformer(**kwargs) - case "time_series_transformer": - return bayesflow.networks.TimeSeriesTransformer(**kwargs) - case "time_series_network": - return bayesflow.networks.LSTNet(**kwargs) - case str() as unknown_network: - raise ValueError(f"Unknown summary network: '{unknown_network}'") - case other: - raise TypeError(f"Unknown transform type: {other}") From 25f5c64fe5c2976015b5cf02bc46aca357c4dbad Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 14:58:56 -0400 Subject: [PATCH 15/46] Update dispatching distributions --- bayesflow/networks/inference_network.py | 3 +-- bayesflow/networks/summary_network.py | 3 +-- bayesflow/utils/__init__.py | 27 +++++++++++++++++++++++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index ae4856b02..b092ce2cb 100644 --- a/bayesflow/networks/inference_network.py +++ 
b/bayesflow/networks/inference_network.py @@ -1,8 +1,7 @@ import keras -from bayesflow.distributions import find_distribution from bayesflow.types import Shape, Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import allow_batch_size diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 6e97c618f..e821be3f3 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -1,9 +1,8 @@ import keras -from bayesflow.distributions import find_distribution from bayesflow.metrics.functional import maximum_mean_discrepancy from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import deserialize diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 73ba7fd8b..737c533ce 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -7,8 +7,11 @@ logging, numpy_utils, ) + from .callbacks import detailed_loss_callback + from .devices import devices + from .dict_utils import ( convert_args, convert_kwargs, @@ -20,30 +23,48 @@ split_arrays, squeeze_inner_estimates_dict, ) -from .dispatch import find_network, find_permutation, find_pooling, find_recurrent_net + +from .dispatch import ( + find_network, + find_permutation, + find_pooling, + find_recurrent_net, + find_summary_network, + find_inference_network, + find_distribution, +) + from .ecdf import simultaneous_ecdf_bands, ranks + from .functional import batched_call + from .git import ( issue_url, pull_url, repo_url, ) + from .hparam_utils import find_batch_size, find_memory_budget + from .integrate import ( integrate, ) + from .io import ( pickle_load, format_bytes, parse_bytes, ) + from .jacobian import ( jacobian, jacobian_trace, jvp, vjp, ) + 
from .optimal_transport import optimal_transport + from .plot_utils import ( check_estimates_prior_shapes, prepare_plot_data, @@ -53,6 +74,7 @@ add_metric, ) from .serialization import serialize_value_or_type, deserialize_value_or_type + from .tensor_utils import ( concatenate_valid, expand, @@ -75,9 +97,10 @@ fill_triangular_matrix, weighted_mean, ) + from .classification import calibration_curve, confusion_matrix + from .validators import check_lengths_same -from .workflow_utils import find_inference_network, find_summary_network from ._docs import _add_imports_to_all From f6a70b5a318e97bb08078dcc3646f77de3cdf84e Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 14:59:22 -0400 Subject: [PATCH 16/46] Improve workflow tests with multiple summary nets / approximators --- tests/test_workflows/conftest.py | 53 +++++++++++++++++---- tests/test_workflows/test_basic_workflow.py | 23 +++++++-- 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/tests/test_workflows/conftest.py b/tests/test_workflows/conftest.py index e9455800b..c98e543e9 100644 --- a/tests/test_workflows/conftest.py +++ b/tests/test_workflows/conftest.py @@ -1,15 +1,52 @@ import pytest +import keras -@pytest.fixture() -def inference_network(): - from bayesflow.networks import CouplingFlow +from bayesflow.utils.serialization import serializable - return CouplingFlow(depth=2) +@pytest.fixture(params=["coupling_flow", "flow_matching"]) +def inference_network(request): + if request.param == "coupling_flow": + from bayesflow.networks import CouplingFlow -@pytest.fixture() -def summary_network(): - from bayesflow.networks import TimeSeriesTransformer + return CouplingFlow(depth=2) - return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(32, 32), mlp_depths=(1, 1)) + elif request.param == "flow_matching": + from bayesflow.networks import FlowMatching + + return FlowMatching(subnet_kwargs=dict(widths=(32, 32)), use_optimal_transport=False) + + 
+@pytest.fixture(params=["time_series_transformer", "fusion_transformer", "time_series_network", "custom"]) +def summary_network(request): + if request.param == "time_series_transformer": + from bayesflow.networks import TimeSeriesTransformer + + return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(16, 8), mlp_depths=(1, 1)) + + elif request.param == "fusion_transformer": + from bayesflow.networks import FusionTransformer + + return FusionTransformer( + embed_dims=(8, 8), mlp_widths=(8, 16), mlp_depths=(2, 1), template_dim=8, bidirectional=False + ) + + elif request.param == "time_series_network": + from bayesflow.networks import TimeSeriesNetwork + + return TimeSeriesNetwork(filters=4, skip_steps=2) + + elif request.param == "custom": + from bayesflow.networks import SummaryNetwork + + @serializable + class Custom(SummaryNetwork): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.inner = keras.Sequential([keras.layers.LSTM(8), keras.layers.Dense(4)]) + + def call(self, x, **kwargs): + return self.inner(x, training=kwargs.get("stage") == "training") + + return Custom() diff --git a/tests/test_workflows/test_basic_workflow.py b/tests/test_workflows/test_basic_workflow.py index 9a1c7815f..a0a3dc83c 100644 --- a/tests/test_workflows/test_basic_workflow.py +++ b/tests/test_workflows/test_basic_workflow.py @@ -1,21 +1,34 @@ +import os + +import keras + import bayesflow as bf -def test_basic_workflow(inference_network, summary_network): +def test_basic_workflow(tmp_path, inference_network, summary_network): workflow = bf.BasicWorkflow( inference_network=inference_network, summary_network=summary_network, inference_variables=["parameters"], summary_variables=["observables"], simulator=bf.simulators.SIR(), + checkpoint_filepath=str(tmp_path), ) - history = workflow.fit_online(epochs=2, batch_size=32, num_batches_per_epoch=2) - plots = workflow.plot_default_diagnostics(test_data=50, num_samples=50) - metrics = 
workflow.compute_default_diagnostics(test_data=50, num_samples=50, variable_names=["p1", "p2"]) + # Ensure metrics work fine + history = workflow.fit_online(epochs=4, batch_size=8, num_batches_per_epoch=2, verbose=0) + plots = workflow.plot_default_diagnostics(test_data=50, num_samples=25) + metrics = workflow.compute_default_diagnostics(test_data=50, num_samples=25, variable_names=["p1", "p2"]) assert "loss" in list(history.history.keys()) - assert len(history.history["loss"]) == 2 + assert len(history.history["loss"]) == 4 assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "z_score_contraction"] assert list(metrics.columns) == ["p1", "p2"] assert metrics.values.shape == (3, 2) + + # Ensure saving and loading from workflow works fine + loaded_approximator = keras.saving.load_model(os.path.join(str(tmp_path), "model.keras")) + + # Get samples + samples = loaded_approximator.sample(conditions=workflow.simulate(5), num_samples=3) + assert samples["parameters"].shape == (5, 3, 2) From 7ce37cfa19f8e6dba553c71e37dba5e7924fb056 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 15:22:34 -0400 Subject: [PATCH 17/46] Fix zombie find_distribution import --- bayesflow/distributions/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bayesflow/distributions/__init__.py b/bayesflow/distributions/__init__.py index ed8d05af7..e9e30b7e4 100644 --- a/bayesflow/distributions/__init__.py +++ b/bayesflow/distributions/__init__.py @@ -9,8 +9,6 @@ from .diagonal_student_t import DiagonalStudentT from .mixture import Mixture -from .find_distribution import find_distribution - from ..utils._docs import _add_imports_to_all _add_imports_to_all(include_modules=[]) From ea5a78db5cb1bc04039fbf076c94d866822843c0 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 15:47:20 -0400 Subject: [PATCH 18/46] Add readme entry [no ci] --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml 
b/pyproject.toml index 6f1df3a4b..a42b1df97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "bayesflow" -version = "2.0.1" +version = "2.0.2" authors = [{ name = "The BayesFlow Team" }] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -19,6 +19,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] description = "Amortizing Bayesian Inference With Neural Networks" +readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } requires-python = ">= 3.10, < 3.12" From dc3cf816dfa0ac66dfb23df1afabd617df130cb9 Mon Sep 17 00:00:00 2001 From: Marvin Schmitt <35921281+marvinschmitt@users.noreply.github.com> Date: Fri, 25 Apr 2025 22:50:58 +0300 Subject: [PATCH 19/46] Update README: NumFOCUS affiliation, awesome-abi list (#445) --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 3f3269754..c4049c470 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ ![Codecov](https://img.shields.io/codecov/c/github/bayesflow-org/bayesflow?style=for-the-badge&link=https%3A%2F%2Fapp.codecov.io%2Fgh%2Fbayesflow-org%2Fbayesflow%2Ftree%2Fmain) [![DOI](https://img.shields.io/badge/DOI-10.21105%2Fjoss.05702-blue?style=for-the-badge)](https://doi.org/10.21105/joss.05702) ![PyPI - License](https://img.shields.io/pypi/l/bayesflow?style=for-the-badge) +![NumFOCUS Affiliated Project](https://img.shields.io/badge/NumFOCUS-Affiliated%20Project-orange?style=for-the-badge) BayesFlow is a Python library for simulation-based **Amortized Bayesian Inference** with neural networks. 
It provides users and researchers with: @@ -225,8 +226,10 @@ You can find and install the old Bayesflow version via the `stable-legacy` branc ## Awesome Amortized Inference -If you are interested in a curated list of resources, including reviews, software, papers, and other resources related to amortized inference, feel free to explore our [community-driven list](https://github.com/bayesflow-org/awesome-amortized-inference). +If you are interested in a curated list of resources, including reviews, software, papers, and other resources related to amortized inference, feel free to explore our [community-driven list](https://github.com/bayesflow-org/awesome-amortized-inference). If you'd like a paper (by yourself or someone else) featured, please add it to the list with a pull request, an issue, or a message to the maintainers. ## Acknowledgments This project is currently managed by researchers from Rensselaer Polytechnic Institute, TU Dortmund University, and Heidelberg University. It is partially funded by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) Projects 528702768 and 508399956. The project is further supported by Germany's Excellence Strategy -- EXC-2075 - 390740016 (Stuttgart Cluster of Excellence SimTech) and EXC-2181 - 390900948 (Heidelberg Cluster of Excellence STRUCTURES), the collaborative research cluster TRR 391 – 520388526, as well as the Informatics for Life initiative funded by the Klaus Tschira Foundation. + +BayesFlow is a [NumFOCUS Affiliated Project](https://numfocus.org/sponsored-projects/affiliated-projects). 
From 3b1c0530b59e55666c5e8bf9d6f36104766fca5c Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 17:22:57 -0400 Subject: [PATCH 20/46] fix is_symbolic_tensor --- bayesflow/utils/tensor_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/bayesflow/utils/tensor_utils.py b/bayesflow/utils/tensor_utils.py index 4d89249b7..72d83076c 100644 --- a/bayesflow/utils/tensor_utils.py +++ b/bayesflow/utils/tensor_utils.py @@ -97,9 +97,6 @@ def is_symbolic_tensor(x: Tensor) -> bool: if keras.utils.is_keras_tensor(x): return True - if not keras.ops.is_tensor(x): - return False - match keras.backend.backend(): case "jax": import jax From c638124244e9d8b7e49353b53fbe22dae98de11f Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 25 Apr 2025 17:30:29 -0400 Subject: [PATCH 21/46] remove multiple batch sizes, remove multiple python version tests, remove update-workflows branch from workflow style tests, add __init__ and conftest to test_point_approximators (#443) --- .github/workflows/style.yaml | 2 -- .github/workflows/tests.yaml | 2 +- tests/conftest.py | 32 +------------------ .../test_point_approximators/__init__.py | 0 .../test_point_approximators/conftest.py | 0 tests/test_distributions/conftest.py | 2 +- tests/test_links/conftest.py | 5 --- tests/test_networks/test_summary_networks.py | 2 +- tests/utils/check_combinations.py | 6 ++-- 9 files changed, 7 insertions(+), 44 deletions(-) create mode 100644 tests/test_approximators/test_point_approximators/__init__.py create mode 100644 tests/test_approximators/test_point_approximators/conftest.py diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index a451ac89d..3c2da4421 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -6,12 +6,10 @@ on: branches: - main - dev - - update-workflows push: branches: - main - dev - - update-workflows jobs: check-code-style: diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f90389f69..ab3d03078 100644 --- 
a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -24,7 +24,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.10"] # we usually only need to test the oldest python version backend: ["jax", "tensorflow", "torch"] runs-on: ${{ matrix.os }} diff --git a/tests/conftest.py b/tests/conftest.py index 6e1e69db1..560b7c59b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,7 @@ def pytest_make_parametrize_id(config, val, argname): return f"{argname}={repr(val)}" -@pytest.fixture(params=[2, 3], scope="session") +@pytest.fixture(params=[2], scope="session") def batch_size(request): return request.param @@ -94,33 +94,3 @@ def random_set(batch_size, set_size, feature_size): @pytest.fixture(params=[2, 3], scope="session") def set_size(request): return request.param - - -@pytest.fixture(params=["two_moons"], scope="session") -def simulator(request): - return request.getfixturevalue(request.param) - - -@pytest.fixture(scope="session") -def training_dataset(simulator, batch_size): - from bayesflow.datasets import OfflineDataset - - num_batches = 128 - samples = simulator.sample((num_batches * batch_size,)) - return OfflineDataset(samples, batch_size=batch_size) - - -@pytest.fixture(scope="session") -def two_moons(batch_size): - from bayesflow.simulators import TwoMoonsSimulator - - return TwoMoonsSimulator() - - -@pytest.fixture(scope="session") -def validation_dataset(simulator, batch_size): - from bayesflow.datasets import OfflineDataset - - num_batches = 16 - samples = simulator.sample((num_batches * batch_size,)) - return OfflineDataset(samples, batch_size=batch_size) diff --git a/tests/test_approximators/test_point_approximators/__init__.py b/tests/test_approximators/test_point_approximators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_approximators/test_point_approximators/conftest.py 
b/tests/test_approximators/test_point_approximators/conftest.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_distributions/conftest.py b/tests/test_distributions/conftest.py index 29c5b4139..e06ed18af 100644 --- a/tests/test_distributions/conftest.py +++ b/tests/test_distributions/conftest.py @@ -3,7 +3,7 @@ import keras -@pytest.fixture(params=[2, 3]) +@pytest.fixture(params=[2]) def batch_size(request): return request.param diff --git a/tests/test_links/conftest.py b/tests/test_links/conftest.py index 53e9eeac8..be7730ef2 100644 --- a/tests/test_links/conftest.py +++ b/tests/test_links/conftest.py @@ -82,8 +82,3 @@ def quantiles(request): @pytest.fixture() def unordered(batch_size, num_quantiles, num_variables): return keras.random.normal((batch_size, num_quantiles, num_variables)) - - -# @pytest.fixture() -# def random_matrix_batch(batch_size, num_variables): -# return keras.random.normal((batch_size, num_variables, num_variables)) diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index 50e1726c1..74ce1f5fd 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -103,7 +103,7 @@ def test_save_and_load(tmp_path, summary_network, random_set): @pytest.mark.parametrize("stage", ["training", "validation"]) def test_compute_metrics(stage, summary_network, random_set): if summary_network is None: - pytest.skip() + pytest.skip("Nothing to do, because there is no summary network.") summary_network.build(keras.ops.shape(random_set)) diff --git a/tests/utils/check_combinations.py b/tests/utils/check_combinations.py index 8d3fa5d46..8565703c8 100644 --- a/tests/utils/check_combinations.py +++ b/tests/utils/check_combinations.py @@ -13,12 +13,12 @@ def check_combination_simulator_adapter(simulator, adapter): with pytest.raises(KeyError): adapter(simulator.sample(1)) # Don't use this fixture combination for further tests. 
- pytest.skip() + pytest.skip(reason="Do not use this fixture combination for further tests") # TODO: better reason elif simulator_with_sample_weight and not adapter_with_sample_weight: # When a weight key is present, but the adapter does not configure it # to be used as sample weight, no error is raised currently. # Don't use this fixture combination for further tests. - pytest.skip() + pytest.skip(reason="Do not use this fixture combination for further tests") # TODO: better reason def check_approximator_multivariate_normal_score(approximator): @@ -28,4 +28,4 @@ def check_approximator_multivariate_normal_score(approximator): if isinstance(approximator, PointApproximator): for score in approximator.inference_network.scores.values(): if isinstance(score, MultivariateNormalScore): - pytest.skip() + pytest.skip(reason="MultivariateNormalScore is unstable") From de8e1cb9c9f1dfd1dcc7992b76d40aef7219e41c Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 25 Apr 2025 17:33:58 -0400 Subject: [PATCH 22/46] implement compile_from_config and get_compile_config (#442) * implement compile_from_config and get_compile_config * add optimizer build to compile_from_config --- .../approximators/continuous_approximator.py | 16 ++++++++++++++++ .../model_comparison_approximator.py | 16 ++++++++++++++++ bayesflow/metrics/maximum_mean_discrepancy.py | 2 ++ bayesflow/metrics/root_mean_squard_error.py | 3 ++- 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index dbd9eba0c..f0c1d68fa 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -104,6 +104,12 @@ def compile( return super().compile(*args, **kwargs) + def compile_from_config(self, config): + self.compile(**deserialize(config)) + if hasattr(self, "optimizer") and self.built: + # Create optimizer variables. 
+ self.optimizer.build(self.trainable_variables) + def compute_metrics( self, inference_variables: Tensor, @@ -213,6 +219,16 @@ def get_config(self): return base_config | serialize(config) + def get_compile_config(self): + base_config = super().get_compile_config() or {} + + config = { + "inference_metrics": self.inference_network._metrics, + "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None, + } + + return base_config | serialize(config) + def estimate( self, conditions: Mapping[str, np.ndarray], diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 1e26f00b0..03b377537 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -118,6 +118,12 @@ def compile( return super().compile(*args, **kwargs) + def compile_from_config(self, config): + self.compile(**deserialize(config)) + if hasattr(self, "optimizer") and self.built: + # Create optimizer variables. 
+ self.optimizer.build(self.trainable_variables) + def compute_metrics( self, *, @@ -262,6 +268,16 @@ def get_config(self): return base_config | serialize(config) + def get_compile_config(self): + base_config = super().get_compile_config() or {} + + config = { + "classifier_metrics": self.classifier_network._metrics, + "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None, + } + + return base_config | serialize(config) + def predict( self, *, diff --git a/bayesflow/metrics/maximum_mean_discrepancy.py b/bayesflow/metrics/maximum_mean_discrepancy.py index 64b8c35a0..37af44fd4 100644 --- a/bayesflow/metrics/maximum_mean_discrepancy.py +++ b/bayesflow/metrics/maximum_mean_discrepancy.py @@ -2,9 +2,11 @@ import keras +from bayesflow.utils.serialization import serializable from .functional import maximum_mean_discrepancy +@serializable class MaximumMeanDiscrepancy(keras.Metric): def __init__( self, diff --git a/bayesflow/metrics/root_mean_squard_error.py b/bayesflow/metrics/root_mean_squard_error.py index 13e724c14..97de62e6a 100644 --- a/bayesflow/metrics/root_mean_squard_error.py +++ b/bayesflow/metrics/root_mean_squard_error.py @@ -1,10 +1,11 @@ from functools import partial import keras - +from bayesflow.utils.serialization import serializable from .functional import root_mean_squared_error +@serializable class RootMeanSquaredError(keras.metrics.MeanMetricWrapper): def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs): fn = partial(root_mean_squared_error, **kwargs) From 16491beae5606953762fbccf99fe368fecd2e580 Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 25 Apr 2025 18:58:54 -0400 Subject: [PATCH 23/46] Fix Optimal Transport for Compiled Contexts (#446) * remove the is_symbolic_tensor check because this would otherwise skip the whole function for compiled contexts * skip pyabc test * fix sinkhorn and log_sinkhorn message formatting for jax by making the warning message worse --- 
.../utils/optimal_transport/log_sinkhorn.py | 18 +++++------------- bayesflow/utils/optimal_transport/sinkhorn.py | 18 +++++------------- tests/test_examples/test_examples.py | 1 + 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 3538eaeff..9fa6dba26 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -1,7 +1,6 @@ import keras from .. import logging -from ..tensor_utils import is_symbolic_tensor from .euclidean import euclidean @@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16) - if is_symbolic_tensor(log_plan): - return log_plan - def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) @@ -57,22 +53,18 @@ def do_nothing(): pass def log_steps(): - msg = "Log-Sinkhorn-Knopp converged after {:d} steps." + msg = "Log-Sinkhorn-Knopp converged after {} steps." logging.debug(msg, steps) def warn_convergence(): - marginals = keras.ops.logsumexp(log_plan, axis=0) - deviations = keras.ops.abs(marginals) - badness = 100.0 * keras.ops.exp(keras.ops.max(deviations)) - - msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)." + msg = "Log-Sinkhorn-Knopp did not converge after {} steps." - logging.warning(msg, max_steps, badness) + logging.warning(msg, max_steps) def warn_nans(): - msg = "Log-Sinkhorn-Knopp produced NaNs." - logging.warning(msg) + msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps." 
+ logging.warning(msg, steps) keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing) keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index 1efa5ae0b..04c268eb0 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -3,7 +3,6 @@ from bayesflow.types import Tensor from .. import logging -from ..tensor_utils import is_symbolic_tensor from .euclidean import euclidean @@ -76,9 +75,6 @@ def sinkhorn_plan( # initialize the transport plan from a gaussian kernel plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16)) - if is_symbolic_tensor(plan): - return plan - def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) @@ -106,22 +102,18 @@ def do_nothing(): pass def log_steps(): - msg = "Sinkhorn-Knopp converged after {:d} steps." + msg = "Sinkhorn-Knopp converged after {} steps." logging.info(msg, max_steps) def warn_convergence(): - marginals = keras.ops.sum(plan, axis=0) - deviations = keras.ops.abs(marginals - 1.0) - badness = 100.0 * keras.ops.max(deviations) - - msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)." + msg = "Sinkhorn-Knopp did not converge after {}." - logging.warning(msg, max_steps, badness) + logging.warning(msg, max_steps) def warn_nans(): - msg = "Sinkhorn-Knopp produced NaNs." - logging.warning(msg) + msg = "Sinkhorn-Knopp produced NaNs after {} steps." 
+ logging.warning(msg, steps) keras.ops.cond(contains_nans(plan), warn_nans, do_nothing) keras.ops.cond(is_converged(plan), log_steps, warn_convergence) diff --git a/tests/test_examples/test_examples.py b/tests/test_examples/test_examples.py index 245052636..40135627a 100644 --- a/tests/test_examples/test_examples.py +++ b/tests/test_examples/test_examples.py @@ -9,6 +9,7 @@ def test_bayesian_experimental_design(examples_path): run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb") +@pytest.mark.skip(reason="requires setting up pyabc") @pytest.mark.slow def test_from_abc_to_bayesflow(examples_path): run_notebook(examples_path / "From_ABC_to_BayesFlow.ipynb") From ec0ee2f187efabf6ecf1cfac97a61f5987b2e2b9 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 19:58:40 -0400 Subject: [PATCH 24/46] update dispatch tests for more coverage --- tests/test_utils/test_dispatch.py | 255 +++++++++++++----------------- 1 file changed, 112 insertions(+), 143 deletions(-) diff --git a/tests/test_utils/test_dispatch.py b/tests/test_utils/test_dispatch.py index 85e326445..df25ea78e 100644 --- a/tests/test_utils/test_dispatch.py +++ b/tests/test_utils/test_dispatch.py @@ -1,201 +1,170 @@ import keras import pytest -# Import the dispatch functions -from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net -from tests.utils import assert_allclose +from bayesflow.utils import find_inference_network, find_distribution, find_summary_network -# --- Tests for find_network.py --- +# --- Tests for find_inference_network.py --- -class DummyMLP: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs +class DummyInferenceNetwork: + def __init__(self, *a, **kw): + self.args = a + self.kwargs = kw -def test_find_network_with_string(monkeypatch): - # Monkeypatch the MLP entry in bayesflow.networks - monkeypatch.setattr("bayesflow.networks.MLP", DummyMLP) - - net = find_network("mlp", 1, key="value") - assert 
isinstance(net, DummyMLP) - assert net.args == (1,) - assert net.kwargs == {"key": "value"} +@pytest.mark.parametrize( + "name,expected_class_path", + [ + ("coupling_flow", "bayesflow.networks.CouplingFlow"), + ("flow_matching", "bayesflow.networks.FlowMatching"), + ("consistency_model", "bayesflow.networks.ConsistencyModel"), + ], +) +def test_find_inference_network_by_name(monkeypatch, name, expected_class_path): + # patch the expected class in bayesflow.networks + components = expected_class_path.split(".") + module_path = ".".join(components[:-1]) + class_name = components[-1] -def test_find_network_with_type(): - class CustomNet: - def __init__(self, x): - self.x = x + dummy_cls = DummyInferenceNetwork + monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls) - net = find_network(CustomNet, 42) - assert isinstance(net, CustomNet) - assert net.x == 42 + net = find_inference_network(name, 1, key="val") + assert isinstance(net, DummyInferenceNetwork) + assert net.args == (1,) + assert net.kwargs == {"key": "val"} -def test_find_network_with_keras_layer(): +def test_find_inference_network_by_keras_layer(): layer = keras.layers.Dense(10) - returned = find_network(layer) - assert returned is layer - - -def test_find_network_invalid_type(): - with pytest.raises(TypeError): - find_network(123) + result = find_inference_network(layer) + assert result is layer -# --- Tests for find_permutation.py --- +def test_find_inference_network_by_keras_model(): + model = keras.models.Sequential() + result = find_inference_network(model) + assert result is model -class DummyRandomPermutation: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs +def test_find_inference_network_unknown_name(): + with pytest.raises(ValueError): + find_inference_network("unknown_network_name") -class DummySwap: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs +def test_find_inference_network_invalid_type(): + with 
pytest.raises(TypeError): + find_inference_network(12345) -class DummyOrthogonalPermutation: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs +# --- Tests for find_distribution.py --- -def test_find_permutation_random(monkeypatch): - type("dummy_mod", (), {"RandomPermutation": DummyRandomPermutation}) - monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.RandomPermutation", DummyRandomPermutation) - perm = find_permutation("random", 99, flag=True) - assert isinstance(perm, DummyRandomPermutation) - assert perm.args == (99,) - assert perm.kwargs == {"flag": True} +class DummyDistribution: + def __init__(self, *a, **kw): + self.args = a + self.kwargs = kw @pytest.mark.parametrize( - "name,dummy_cls", - [("swap", DummySwap), ("learnable", DummyOrthogonalPermutation), ("orthogonal", DummyOrthogonalPermutation)], + "name, expected_class_path", + [ + ("normal", "bayesflow.distributions.DiagonalNormal"), + ("student", "bayesflow.distributions.DiagonalStudentT"), + ("student-t", "bayesflow.distributions.DiagonalStudentT"), + ("student_t", "bayesflow.distributions.DiagonalStudentT"), + ], ) -def test_find_permutation_by_name(monkeypatch, name, dummy_cls): - # Inject dummy classes for each permutation type - if name == "swap": - monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.Swap", dummy_cls) - else: - monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.OrthogonalPermutation", dummy_cls) - perm = find_permutation(name, "a", b="c") - assert isinstance(perm, dummy_cls) - assert perm.args == ("a",) - assert perm.kwargs == {"b": "c"} - +def test_find_distribution_by_name(monkeypatch, name, expected_class_path): + components = expected_class_path.split(".") + module_path = ".".join(components[:-1]) + class_name = components[-1] -def test_find_permutation_with_keras_layer(): - layer = keras.layers.Activation("relu") - perm = find_permutation(layer) - assert perm is layer + dummy_cls = 
DummyDistribution + monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls) + dist = find_distribution(name, 10, a=5) + assert isinstance(dist, DummyDistribution) + assert dist.args == (10,) + assert dist.kwargs == {"a": 5} -def test_find_permutation_with_none(): - res = find_permutation(None) - assert res is None - - -def test_find_permutation_invalid_type(): - with pytest.raises(TypeError): - find_permutation(3.14) +def test_find_distribution_none_returns_none(): + assert find_distribution(None) is None -# --- Tests for find_pooling.py --- +def test_find_distribution_with_keras_layer(): + layer = keras.layers.Dense(3) + result = find_distribution(layer) + assert result is layer -def dummy_pooling_constructor(*args, **kwargs): - return {"args": args, "kwargs": kwargs} +def test_find_distribution_mixture_raises(): + with pytest.raises(ValueError): + find_distribution("mixture") -def test_find_pooling_mean(): - pooling = find_pooling("mean") - # Check that a keras Lambda layer is returned - assert isinstance(pooling, keras.layers.Lambda) - # Test that the lambda function produces a mean when applied to a sample tensor. - - sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]]) - # Keras Lambda layers expect tensors via call(), here we simply call the layer's function. 
- result = pooling.call(sample) - assert_allclose(result, keras.ops.mean(sample, axis=-2)) - - -@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)]) -def test_find_pooling_max_min(name, func): - pooling = find_pooling(name) - assert isinstance(pooling, keras.layers.Lambda) - - sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]]) - result = pooling.call(sample) - assert_allclose(result, func(sample, axis=-2)) - - -def test_find_pooling_learnable(monkeypatch): - # Monkey patch the PoolingByMultiHeadAttention in its module - class DummyPoolingAttention: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - monkeypatch.setattr("bayesflow.networks.transformers.pma.PoolingByMultiHeadAttention", DummyPoolingAttention) - pooling = find_pooling("learnable", 7, option="test") - assert isinstance(pooling, DummyPoolingAttention) - assert pooling.args == (7,) - assert pooling.kwargs == {"option": "test"} +def test_find_distribution_invalid_name(): + with pytest.raises(ValueError): + find_distribution("invalid_name") -def test_find_pooling_with_constructor(): - # Passing a type should result in an instance. 
- class DummyPooling: - def __init__(self, data): - self.data = data - pooling = find_pooling(DummyPooling, "dummy") - assert isinstance(pooling, DummyPooling) - assert pooling.data == "dummy" +def test_find_distribution_invalid_type(): + with pytest.raises(TypeError): + find_distribution(3.14) -def test_find_pooling_with_keras_layer(): - layer = keras.layers.ReLU() - pooling = find_pooling(layer) - assert pooling is layer +# --- Tests for find_summary_network.py --- -def test_find_pooling_invalid_type(): - with pytest.raises(TypeError): - find_pooling(123) +class DummySummaryNetwork: + def __init__(self, *a, **kw): + self.args = a + self.kwargs = kw -# --- Tests for find_recurrent_net.py --- +@pytest.mark.parametrize( + "name,expected_class_path", + [ + ("deep_set", "bayesflow.networks.DeepSet"), + ("set_transformer", "bayesflow.networks.SetTransformer"), + ("fusion_transformer", "bayesflow.networks.FusionTransformer"), + ("time_series_transformer", "bayesflow.networks.TimeSeriesTransformer"), + ("time_series_network", "bayesflow.networks.TimeSeriesNetwork"), + ], +) +def test_find_summary_network_by_name(monkeypatch, name, expected_class_path): + components = expected_class_path.split(".") + module_path = ".".join(components[:-1]) + class_name = components[-1] + dummy_cls = DummySummaryNetwork + monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls) -def test_find_recurrent_net_lstm(): - constructor = find_recurrent_net("lstm") - assert constructor is keras.layers.LSTM + net = find_summary_network(name, 22, flag=True) + assert isinstance(net, DummySummaryNetwork) + assert net.args == (22,) + assert net.kwargs == {"flag": True} -def test_find_recurrent_net_gru(): - constructor = find_recurrent_net("gru") - assert constructor is keras.layers.GRU +def test_find_summary_network_by_keras_layer(): + layer = keras.layers.Dense(1) + out = find_summary_network(layer) + assert out is layer -def test_find_recurrent_net_with_keras_layer(): - layer = 
keras.layers.SimpleRNN(5) - net = find_recurrent_net(layer) - assert net is layer +def test_find_summary_network_by_keras_model(): + model = keras.models.Sequential() + out = find_summary_network(model) + assert out is model -def test_find_recurrent_net_invalid_name(): +def test_find_summary_network_unknown_name(): with pytest.raises(ValueError): - find_recurrent_net("invalid_net") + find_summary_network("unknown_summary_net") -def test_find_recurrent_net_invalid_type(): +def test_find_summary_network_invalid_type(): with pytest.raises(TypeError): - find_recurrent_net(3.1415) + find_summary_network(0.1234) From acf1c722e536c6b99c893b5e65d026076b8eb531 Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 25 Apr 2025 20:00:39 -0400 Subject: [PATCH 25/46] Update issue templates (#448) * Hotfix Version 2.0.1 (#431) * fix optimal transport config (#429) * run linter * [skip-ci] bump version to 2.0.1 * Update issue templates --- .github/ISSUE_TEMPLATE/bug_report.md | 36 +++++++++++++++++++++++ .github/ISSUE_TEMPLATE/feature_request.md | 20 +++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..a901605ee --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,36 @@ +--- +name: Bug report +about: Create a bug report to help us improve BayesFlow +title: "[BUG]" +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Minimal steps to reproduce the behavior: +1. Import '...' +2. Create network '....' +3. Call '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Traceback** +If you encounter an error, please provide a complete traceback to help explain your problem. + +**Environment** +- OS: [e.g. 
Ubuntu] +- Python Version: [e.g. 3.11] +- Backend: [e.g. jax, tensorflow, pytorch] +- BayesFlow Version: [e.g. 2.0.2] + +**Additional context** +Add any other context about the problem here. + +**Minimality** +- [ ] I verify that my example is minimal, does not rely on third-party packages, and is most likely an issue in BayesFlow. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..da5db4b74 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest a new feature to be implemented in BayesFlow +title: "[FEATURE]" +labels: feature +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. 
From d24f5a3dc0899366fd73419cc8fd6c89f7f36acb Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 20:03:22 -0400 Subject: [PATCH 26/46] Robustify kwargs passing inference networks, add class variables --- bayesflow/approximators/approximator.py | 2 +- .../approximators/continuous_approximator.py | 20 +++++++-- .../model_comparison_approximator.py | 12 +++-- bayesflow/approximators/point_approximator.py | 4 +- .../consistency_models/consistency_model.py | 44 ++++++++----------- .../networks/coupling_flow/coupling_flow.py | 4 +- 6 files changed, 51 insertions(+), 35 deletions(-) diff --git a/bayesflow/approximators/approximator.py b/bayesflow/approximators/approximator.py index e09751b3d..825e93d32 100644 --- a/bayesflow/approximators/approximator.py +++ b/bayesflow/approximators/approximator.py @@ -23,7 +23,7 @@ def build_adapter(cls, **kwargs) -> Adapter: raise NotImplementedError def build_from_data(self, data: Mapping[str, any]) -> None: - self.compute_metrics(**data, stage="training") + self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training") self.built = True @classmethod diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index f0c1d68fa..dcb661ca0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -32,6 +32,8 @@ class ContinuousApproximator(Approximator): Additional arguments passed to the :py:class:`bayesflow.approximators.Approximator` class. """ + SAMPLE_KEYS = ["summary_variables", "inference_conditions"] + def __init__( self, *, @@ -51,6 +53,7 @@ def build_adapter( inference_variables: Sequence[str], inference_conditions: Sequence[str] = None, summary_variables: Sequence[str] = None, + standardize: bool = True, sample_weight: str = None, ) -> Adapter: """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator. 
@@ -63,9 +66,12 @@ def build_adapter( Names of the inference conditions in the data summary_variables : Sequence of str, optional Names of the summary variables in the data + standardize : bool, optional + Decide whether to standardize all variables, default is True sample_weight : str, optional Name of the sample weights """ + adapter = Adapter() adapter.to_array() adapter.convert_dtype("float64", "float32") @@ -82,7 +88,9 @@ def build_adapter( adapter = adapter.rename(sample_weight, "sample_weight") adapter.keep(["inference_variables", "inference_conditions", "summary_variables", "sample_weight"]) - adapter.standardize(exclude="sample_weight") + + if standardize: + adapter.standardize(exclude="sample_weight") return adapter @@ -334,12 +342,18 @@ def sample( dict[str, np.ndarray] Dictionary containing generated samples with the same keys as `conditions`. """ + + # Apply adapter transforms to raw simulated / real quantities conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) - # at inference time, inference_variables are estimated by the networks and thus ignored in conditions - conditions.pop("inference_variables", None) + + # Ensure only keys relevant for sampling are present in the conditions dictionary + conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.SAMPLE_KEYS} + conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions) conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)} conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions) + + # Back-transform quantities and samples conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) if split: diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 03b377537..94e8ebc63 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ 
b/bayesflow/approximators/model_comparison_approximator.py @@ -30,11 +30,13 @@ class ModelComparisonApproximator(Approximator): The network backbone (e.g, an MLP) that is used for model classification. The input of the classifier network is created by concatenating `classifier_variables` and (optional) output of the summary_network. - summary_network: bg.networks.SummaryNetwork, optional + summary_network: bf.networks.SummaryNetwork, optional The summary network used for data summarization (default is None). The input of the summary network is `summary_variables`. """ + SAMPLE_KEYS = ["summary_variables", "inference_conditions"] + def __init__( self, *, @@ -304,9 +306,13 @@ def predict( np.ndarray Predicted posterior model probabilities given `conditions`. """ + + # Apply adapter transforms to raw simulated / real quantities conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) - # at inference time, model_indices are predicted by the networks and thus ignored in conditions - conditions.pop("model_indices", None) + + # Ensure only keys relevant for sampling are present in the conditions dictionary + conditions = {k: v for k, v in conditions.items() if k in ModelComparisonApproximator.SAMPLE_KEYS} + conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions) output = self._predict(**conditions, **kwargs) diff --git a/bayesflow/approximators/point_approximator.py b/bayesflow/approximators/point_approximator.py index 457b23138..1e407e2a6 100644 --- a/bayesflow/approximators/point_approximator.py +++ b/bayesflow/approximators/point_approximator.py @@ -156,8 +156,10 @@ def log_prob( def _prepare_conditions(self, conditions: Mapping[str, np.ndarray], **kwargs) -> dict[str, Tensor]: """Adapts and converts the conditions to tensors.""" + conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) - conditions.pop("inference_variables", None) + conditions = {k: v for k, v in conditions.items() if k in 
ContinuousApproximator.SAMPLE_KEYS} + return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions) def _apply_inverse_adapter_to_estimates( diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index 3bcd79a0d..b8d4c56ed 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -187,7 +187,7 @@ def build(self, xz_shape, conditions_shape=None): self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1]) self.c_huber2 = self.c_huber**2 - ## Calculate discretization schedule in advance + # Calculate discretization schedule in advance # The Jax compiler requires fixed-size arrays, so we have # to store all the discretized_times in one matrix in advance # and later only access the relevant entries. @@ -213,34 +213,24 @@ def build(self, xz_shape, conditions_shape=None): disc = ops.convert_to_numpy(self._discretize_time(n)) discretized_times[i, : len(disc)] = disc discretization_map[n] = i + # Finally, we convert the vectors to tensors self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32") self.discretization_map = ops.convert_to_tensor(discretization_map) - def call( - self, - xz: Tensor, - conditions: Tensor = None, - inverse: bool = False, - **kwargs, - ): - if inverse: - return self._inverse(xz, conditions=conditions, **kwargs) - return self._forward(xz, conditions=conditions, **kwargs) - - def _forward_train(self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: - """Forward function for training. Calls consistency function with - noisy input - """ + def _forward_train( + self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False, **kwargs + ) -> Tensor: + """Forward function for training. 
Calls consistency function with noisy input""" + inp = x + t * noise - return self.consistency_function(inp, t, conditions=conditions, **kwargs) + return self.consistency_function(inp, t, conditions=conditions, training=training) def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: # Consistency Models only learn the direction from noise distribution # to target distribution, so we cannot implement this function. raise NotImplementedError("Consistency Models are not invertible") - def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: + def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor: """Generate random draws from the approximate target distribution using the multistep sampling algorithm from [1], Algorithm 1. @@ -249,7 +239,9 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: z : Tensor Samples from a standard normal distribution conditions : Tensor, optional, default: None - Conditions for a approximate conditional distribution + Conditions for the approximate conditional distribution + training : bool, optional, default: False + Whether internal layers (e.g., dropout) should behave in train or inference mode. **kwargs : dict, optional, default: {} Additional keyword arguments. Include `steps` (default: 10) to adjust the number of sampling steps. 
@@ -263,15 +255,17 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: x = keras.ops.copy(z) * self.max_time discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1) t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype) - x = self.consistency_function(x, t, conditions=conditions) + + x = self.consistency_function(x, t, conditions=conditions, training=training) + for n in range(1, steps): noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator) x_n = x + keras.ops.sqrt(keras.ops.square(discretized_time[n]) - self.eps**2) * noise t = keras.ops.full_like(t, discretized_time[n]) - x = self.consistency_function(x_n, t, conditions=conditions) + x = self.consistency_function(x_n, t, conditions=conditions, training=training) return x - def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: + def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor: """Compute consistency function. Parameters @@ -282,8 +276,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, Vector of time samples in [eps, T] conditions : Tensor The conditioning vector - **kwargs : dict, optional, default: {} - Additional keyword arguments passed to the network. + training : bool, optional, default: False + Whether internal layers (e.g., dropout) should behave in train or inference mode. 
""" if conditions is not None: @@ -291,7 +285,7 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, else: xtc = ops.concatenate([x, t], axis=-1) - f = self.output_projector(self.subnet(xtc, **kwargs)) + f = self.output_projector(self.subnet(xtc, training=training)) # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim) # Thus, we can do a cross product with the time vector which is (batch_size, 1) for diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index ee78f180e..203962b0f 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -152,7 +152,7 @@ def _forward( z = x log_det = keras.ops.zeros(keras.ops.shape(x)[:-1]) for layer in self.invertible_layers: - z, det = layer(z, conditions=conditions, inverse=False, training=training, **kwargs) + z, det = layer(z, conditions=conditions, inverse=False, training=training) log_det += det if density: @@ -168,7 +168,7 @@ def _inverse( x = z log_det = keras.ops.zeros(keras.ops.shape(z)[:-1]) for layer in reversed(self.invertible_layers): - x, det = layer(x, conditions=conditions, inverse=True, training=training, **kwargs) + x, det = layer(x, conditions=conditions, inverse=True, training=training) log_det += det if density: From 463c0c7e7f8c5f3f4a990e9b6bc002dd9b6c1130 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 20:13:41 -0400 Subject: [PATCH 27/46] fix convergence method to debug for non-log sinkhorn --- bayesflow/utils/optimal_transport/sinkhorn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index 04c268eb0..f7e0ba835 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -104,7 +104,7 @@ def do_nothing(): def log_steps(): msg = "Sinkhorn-Knopp 
converged after {} steps." - logging.info(msg, max_steps) + logging.debug(msg, max_steps) def warn_convergence(): msg = "Sinkhorn-Knopp did not converge after {}." From 8f3739c6c0d5731030228ddd5b4d0021295f7257 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 25 Apr 2025 20:18:30 -0400 Subject: [PATCH 28/46] Bump optimal transport default to False --- bayesflow/networks/flow_matching/flow_matching.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 7a097d340..3c0190467 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -55,7 +55,7 @@ def __init__( self, subnet: str | keras.Layer = "mlp", base_distribution: str | Distribution = "normal", - use_optimal_transport: bool = True, + use_optimal_transport: bool = False, loss_fn: str | keras.Loss = "mse", integrate_kwargs: dict[str, any] = None, optimal_transport_kwargs: dict[str, any] = None, @@ -82,7 +82,8 @@ def __init__( The base probability distribution from which samples are drawn, such as "normal". Default is "normal". use_optimal_transport : bool, optional - Whether to apply optimal transport for improved training stability. Default is True. + Whether to apply optimal transport for improved training stability. Default is False. + Note: enabling optimal transport increases training time by a factor of roughly 2.5, but may lead to faster inference. loss_fn : str, optional The loss function used for training, such as "mse". Default is "mse". 
integrate_kwargs : dict[str, any], optional From 40eccd4ed678ce49710ccff3f9239a23e2dedae4 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:42:01 -0400 Subject: [PATCH 29/46] use logging.info for backend selection instead of logging.debug --- bayesflow/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py index 7a358341e..5a28ffe2e 100644 --- a/bayesflow/__init__.py +++ b/bayesflow/__init__.py @@ -33,7 +33,7 @@ def setup(): from bayesflow.utils import logging - logging.debug(f"Using backend {keras.backend.backend()!r}") + logging.info(f"Using backend {keras.backend.backend()!r}") if keras.backend.backend() == "torch": import torch From 8903089082979501eb781831267e181b1c36679a Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:42:09 -0400 Subject: [PATCH 30/46] fix model comparison approximator --- bayesflow/approximators/model_comparison_approximator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 94e8ebc63..1b9d198ff 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -35,7 +35,7 @@ class ModelComparisonApproximator(Approximator): The input of the summary network is `summary_variables`. 
""" - SAMPLE_KEYS = ["summary_variables", "inference_conditions"] + SAMPLE_KEYS = ["summary_variables", "classifier_conditions"] def __init__( self, From cbc86b8e8f2a5ffc11aa274965feaa6f8aee5555 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:42:16 -0400 Subject: [PATCH 31/46] improve docs and type hints --- bayesflow/simulators/lambda_simulator.py | 6 +++--- bayesflow/simulators/model_comparison_simulator.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bayesflow/simulators/lambda_simulator.py b/bayesflow/simulators/lambda_simulator.py index c6baa7edb..aadefea6e 100644 --- a/bayesflow/simulators/lambda_simulator.py +++ b/bayesflow/simulators/lambda_simulator.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Sequence, Mapping +from collections.abc import Callable, Sequence import numpy as np @@ -12,13 +12,13 @@ class LambdaSimulator(Simulator): """Implements a simulator based on a sampling function.""" - def __init__(self, sample_fn: Callable[[Sequence[int]], Mapping[str, any]], *, is_batched: bool = False): + def __init__(self, sample_fn: Callable[[Sequence[int]], dict[str, any]], *, is_batched: bool = False): """ Initialize a simulator based on a simple callable function Parameters ---------- - sample_fn : Callable[[Sequence[int]], Mapping[str, any]] + sample_fn : Callable[[Sequence[int]], dict[str, any]] A function that generates samples. It should accept `batch_shape` as its first argument (if `is_batched=True`), followed by keyword arguments. 
is_batched : bool, optional diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index e05f4bbc4..60174ef92 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence import numpy as np from bayesflow.types import Shape @@ -22,10 +22,10 @@ def __init__( p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, - shared_simulator: Simulator | FunctionType = None, + shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None, ): """ - Initialize a multi-model simulator that can generate data for mixture / model comparison problems. + Initialize a multimodel simulator that can generate data for mixture / model comparison problems. Parameters ---------- @@ -40,7 +40,7 @@ def __init__( use_mixed_batches : bool, optional If True, samples in a batch are drawn from different models. If False, the entire batch is drawn from a single model chosen according to the model probabilities. Default is True. - shared_simulator : Simulator or FunctionType, optional + shared_simulator : Simulator or Callable, optional A shared simulator whose outputs are passed to all model simulators. If a function is provided, it is wrapped in a `LambdaSimulator` with batching enabled. 
""" From 77ddc5ac9f176386831a279d9f4df96d70ae9fef Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:44:13 -0400 Subject: [PATCH 32/46] improve One-Sample T-Test Notebook: - use torch as default backend - reduce range of N so users of jax won't be stuck with a slow notebook - use BayesFlow built-in MLP instead of keras.Sequential solution - general code cleanup --- examples/One_Sample_TTest.ipynb | 325 +++++++++++++++++++------------- 1 file changed, 193 insertions(+), 132 deletions(-) diff --git a/examples/One_Sample_TTest.ipynb b/examples/One_Sample_TTest.ipynb index 73b0ba6db..d75dcff53 100644 --- a/examples/One_Sample_TTest.ipynb +++ b/examples/One_Sample_TTest.ipynb @@ -22,13 +22,29 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:13.774900Z", + "start_time": "2025-04-26T02:34:11.487313Z" + } + }, + "source": [ + "import numpy as np\n", + "\n", + "import os\n", + "if \"KERAS_BACKEND\" not in os.environ:\n", + " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", + " os.environ[\"KERAS_BACKEND\"] = \"torch\"\n", + "\n", + "import keras\n", + "import bayesflow as bf" + ], "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ + "INFO:bayesflow:Using backend 'torch'\n", "WARNING:bayesflow:\n", "When using torch backend, we need to disable autograd by default to avoid excessive memory usage. 
Use\n", "\n", @@ -39,17 +55,7 @@ ] } ], - "source": [ - "import numpy as np\n", - "\n", - "import os\n", - "if \"KERAS_BACKEND\" not in os.environ:\n", - " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", - " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", - "\n", - "import keras\n", - "import bayesflow as bf" - ] + "execution_count": 1 }, { "cell_type": "markdown", @@ -76,21 +82,24 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:13.780691Z", + "start_time": "2025-04-26T02:34:13.777756Z" + } + }, "source": [ "def context(batch_shape, n=None):\n", " if n is None:\n", - " n = np.random.randint(5, 50)\n", - " return dict(n = n)\n", + " n = np.random.randint(20, 30)\n", + " return dict(n=n)\n", "\n", "def prior_null():\n", - " return dict(mu = 0.0)\n", + " return dict(mu=0.0)\n", "\n", "def prior_alternative():\n", " mu = np.random.normal(loc=0, scale=1)\n", - " return dict(mu = mu)\n", + " return dict(mu=mu)\n", "\n", "def likelihood(n, mu):\n", " x = np.random.normal(loc=mu, scale=1, size=n)\n", @@ -101,8 +110,11 @@ "simulator = bf.simulators.ModelComparisonSimulator(\n", " simulators=[simulator_null, simulator_alternative], \n", " use_mixed_batches=True, \n", - " shared_simulator=context)" - ] + " shared_simulator=context,\n", + ")" + ], + "outputs": [], + "execution_count": 2 }, { "cell_type": "markdown", @@ -113,27 +125,32 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:13.840198Z", + "start_time": "2025-04-26T02:34:13.837161Z" + } + }, + "source": [ + "data = simulator.sample(100)\n", + "print(\"n =\", data[\"n\"])\n", + "for key, value in data.items():\n", + " print(key + \" shape:\", np.array(value).shape)" + ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "n = 37\n", + "n = 20\n", "n shape: ()\n", "mu shape: (100, 1)\n", - "x shape: (100, 
37)\n", + "x shape: (100, 20)\n", "model_indices shape: (100, 2)\n" ] } ], - "source": [ - "data = simulator.sample(100)\n", - "print(\"n =\", data[\"n\"])\n", - "for key, value in data.items():\n", - " print(key + \" shape:\", np.array(value).shape)" - ] + "execution_count": 3 }, { "cell_type": "markdown", @@ -155,9 +172,12 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:13.884625Z", + "start_time": "2025-04-26T02:34:13.882973Z" + } + }, "source": [ "adapter = (\n", " bf.Adapter()\n", @@ -166,10 +186,12 @@ " .as_set(\"x\")\n", " .rename(\"n\", \"classifier_conditions\")\n", " .rename(\"x\", \"summary_variables\")\n", - " .drop('mu')\n", + " .drop(\"mu\")\n", " .convert_dtype(\"float64\", \"float32\")\n", " )" - ] + ], + "outputs": [], + "execution_count": 4 }, { "cell_type": "markdown", @@ -188,8 +210,17 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:13.930048Z", + "start_time": "2025-04-26T02:34:13.928375Z" + } + }, + "source": [ + "processed_data=adapter(data)\n", + "for key, value in processed_data.items():\n", + " print(key + \" shape:\", value.shape)" + ], "outputs": [ { "name": "stdout", @@ -197,15 +228,11 @@ "text": [ "model_indices shape: (100, 2)\n", "classifier_conditions shape: (100, 1)\n", - "summary_variables shape: (100, 37, 1)\n" + "summary_variables shape: (100, 20, 1)\n" ] } ], - "source": [ - "processed_data=adapter(data)\n", - "for key, value in processed_data.items():\n", - " print(key + \" shape:\", value.shape)" - ] + "execution_count": 5 }, { "cell_type": "markdown", @@ -231,15 +258,18 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:13.996060Z", + "start_time": "2025-04-26T02:34:13.974940Z" + } + }, "source": [ - "summary_network = 
bf.networks.DeepSet(summary_dim=4, dropout=0.0)\n", - "classifier_network = keras.Sequential(\n", - " [keras.layers.Dense(32, activation=\"silu\") for _ in range(4)]\n", - ")" - ] + "summary_network = bf.networks.DeepSet(summary_dim=8, dropout=None)\n", + "classifier_network = bf.networks.MLP(widths=[32] * 4, activation=\"silu\", dropout=None)" + ], + "outputs": [], + "execution_count": 6 }, { "cell_type": "markdown", @@ -250,17 +280,22 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:14.021809Z", + "start_time": "2025-04-26T02:34:14.019827Z" + } + }, "source": [ "approximator = bf.approximators.ModelComparisonApproximator(\n", - " num_models=2, \n", + " num_models=2,\n", " classifier_network=classifier_network,\n", - " summary_network=summary_network, \n", - " adapter=adapter\n", + " summary_network=summary_network,\n", + " adapter=adapter,\n", ")" - ] + ], + "outputs": [], + "execution_count": 7 }, { "cell_type": "markdown", @@ -275,14 +310,19 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:14.065229Z", + "start_time": "2025-04-26T02:34:14.063948Z" + } + }, "source": [ - "num_batches = 64\n", + "num_batches_per_epoch = 64\n", "batch_size = 512\n", - "epochs = 20" - ] + "epochs = 32" + ], + "outputs": [], + "execution_count": 8 }, { "cell_type": "markdown", @@ -293,14 +333,19 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:34:14.118477Z", + "start_time": "2025-04-26T02:34:14.109925Z" + } + }, "source": [ - "learning_rate = keras.optimizers.schedules.CosineDecay(1e-4, decay_steps=epochs*num_batches, alpha=1e-5)\n", - "optimizer = keras.optimizers.Adam(learning_rate=learning_rate, clipnorm=1.0)\n", + "learning_rate = 
keras.optimizers.schedules.CosineDecay(1e-4, decay_steps=epochs * num_batches_per_epoch)\n", + "optimizer = keras.optimizers.Adam(learning_rate=learning_rate)\n", "approximator.compile(optimizer=optimizer)" - ] + ], + "outputs": [], + "execution_count": 9 }, { "cell_type": "markdown", @@ -311,18 +356,18 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "history = approximator.fit(\n", " epochs=epochs,\n", - " num_batches=num_batches,\n", + " num_batches=num_batches_per_epoch,\n", " batch_size=batch_size,\n", " simulator=simulator,\n", - " adapter=adapter\n", + " adapter=adapter,\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -333,23 +378,26 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:35:17.051898Z", + "start_time": "2025-04-26T02:35:16.960475Z" + } + }, + "source": "f = bf.diagnostics.plots.loss(history=history)", "outputs": [ { "data": { - "image/png": "", "text/plain": [ "
" - ] + ], + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "source": [ - "f = bf.diagnostics.plots.loss(history=history)" - ] + "execution_count": 11 }, { "cell_type": "markdown", @@ -364,51 +412,66 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:35:17.556053Z", + "start_time": "2025-04-26T02:35:17.058765Z" + } + }, + "source": [ + "df = simulator.sample(5000, n=10)\n", + "print(f\"{df['n']=}\")\n", + "print(f\"{df['x'].shape=}\")" + ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10\n", - "(5000, 10)\n" + "df['n']=10\n", + "df['x'].shape=(30000, 10)\n" ] } ], - "source": [ - "df=simulator.sample(5000, n=10)\n", - "print(df[\"n\"])\n", - "print(df[\"x\"].shape)" - ] + "execution_count": 12 }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "To apply our approximator on this dataset, we simply use the `.predict` method to obtain the predicted posterior model probabilities, given the data and the approximator." - ] + "cell_type": "markdown", + "source": "To apply our approximator on this dataset, we simply use the `.predict` method to obtain the predicted posterior model probabilities, given the data and the approximator." }, { + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:35:17.599619Z", + "start_time": "2025-04-26T02:35:17.563524Z" + } + }, "cell_type": "code", - "execution_count": 20, - "metadata": {}, + "source": "pred_models = approximator.predict(conditions=df)", "outputs": [], - "source": [ - "pred_models=approximator.predict(conditions=df)" - ] + "execution_count": 13 }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "We inspect the model comparison calibration now." - ] + "cell_type": "markdown", + "source": "We inspect the model comparison calibration now." 
}, { + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:35:17.772515Z", + "start_time": "2025-04-26T02:35:17.612303Z" + } + }, "cell_type": "code", - "execution_count": 21, - "metadata": {}, + "source": [ + "f=bf.diagnostics.plots.mc_calibration(\n", + " pred_models=pred_models,\n", + " true_models=df[\"model_indices\"],\n", + " model_names=[r\"$\\mathcal{M}_0$\",r\"$\\mathcal{M}_1$\"],\n", + ")" + ], "outputs": [ { "name": "stderr", @@ -422,33 +485,38 @@ }, { "data": { - "image/png": "", "text/plain": [ "
" - ] + ], + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "source": [ - "f=bf.diagnostics.plots.mc_calibration(\n", - " pred_models=pred_models, \n", - " true_models=df[\"model_indices\"],\n", - " model_names=[r\"$\\mathcal{M}_0$\",r\"$\\mathcal{M}_1$\"])" - ] + "execution_count": 14 }, { - "cell_type": "markdown", "metadata": {}, - "source": [ - "And the confusion matrix to inspect how often we would make an accurate decision based on picking the model with the highest posterior probability." - ] + "cell_type": "markdown", + "source": "And the confusion matrix to inspect how often we would make an accurate decision based on picking the model with the highest posterior probability." }, { "cell_type": "code", - "execution_count": 22, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-26T02:35:17.851048Z", + "start_time": "2025-04-26T02:35:17.784250Z" + } + }, + "source": [ + "f=bf.diagnostics.plots.mc_confusion_matrix(\n", + " pred_models=pred_models,\n", + " true_models=df[\"model_indices\"],\n", + " model_names=[r\"$\\mathcal{M}_0$\",r\"$\\mathcal{M}_1$\"],\n", + " normalize=\"true\",\n", + ")" + ], "outputs": [ { "name": "stderr", @@ -462,23 +530,16 @@ }, { "data": { - "image/png": "", "text/plain": [ "
" - ] + ], + "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], - "source": [ - "f=bf.diagnostics.plots.mc_confusion_matrix(\n", - " pred_models=pred_models,\n", - " true_models=df['model_indices'], \n", - " model_names=[r\"$\\mathcal{M}_0$\",r\"$\\mathcal{M}_1$\"],\n", - " normalize=\"true\"\n", - ")" - ] + "execution_count": 15 } ], "metadata": { From ad011711b398ceb6650a3a16d7abdc73ea94dfe2 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:44:44 -0400 Subject: [PATCH 33/46] remove backend print --- examples/SIR_Posterior_Estimation.ipynb | 35 +++++++------------------ 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/examples/SIR_Posterior_Estimation.ipynb b/examples/SIR_Posterior_Estimation.ipynb index c7dafa37f..7963d00e5 100644 --- a/examples/SIR_Posterior_Estimation.ipynb +++ b/examples/SIR_Posterior_Estimation.ipynb @@ -11,40 +11,24 @@ ] }, { - "cell_type": "code", - "execution_count": 1, - "id": "0383ba66", "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "import os\n", "# Set to your favorite backend\n", "if \"KERAS_BACKEND\" not in os.environ:\n", " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", - " os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", - "else:\n", - " print(f\"Using '{os.environ['KERAS_BACKEND']}' backend\")" - ] + " os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"" + ], + "id": "5fb5c0f856b6bcf4" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 2, - "id": "684f2d7e19d40e09", - "metadata": { - "ExecuteTime": { - "end_time": "2025-04-11T19:54:02.700953Z", - "start_time": "2025-04-11T19:53:33.926075Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:2025-04-21 12:41:48,425:jax._src.xla_bridge:967: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" - ] - } - ], + "outputs": [], + "execution_count": null, "source": [ "import datetime\n", "\n", @@ -55,7 +39,8 @@ "import keras\n", "\n", "import bayesflow as bf" - ] + ], + "id": "4a9355783f1314a" }, { "cell_type": "markdown", From a742d9c66254895f2e7ee14a2599ec737c871fda Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:45:17 -0400 Subject: [PATCH 34/46] [skip ci] turn all single-quoted strings into double-quoted strings --- examples/From_ABC_to_BayesFlow.ipynb | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/From_ABC_to_BayesFlow.ipynb b/examples/From_ABC_to_BayesFlow.ipynb index b9757d9c4..35947a4cd 100644 --- a/examples/From_ABC_to_BayesFlow.ipynb +++ b/examples/From_ABC_to_BayesFlow.ipynb @@ -479,7 +479,7 @@ "\n", "# For BayesFlow devs: this ensures that the latest dev version can be found\n", "import sys\n", - "sys.path.append('../')\n", + "sys.path.append(\"../\")\n", "\n", "import bayesflow as bf" ] @@ -513,11 +513,11 @@ "source": [ "def prior_helper():\n", " \"\"\"The ABC prior returns a Parameter Object from pyabc which we convert to a dict.\"\"\"\n", - " return dict(rate=prior.rvs()['rate'])\n", + " return dict(rate=prior.rvs()[\"rate\"])\n", "\n", "def sim_helper(rate):\n", " \"\"\"The simulator returns a dict, we extract the output at the test times.\"\"\"\n", - " temp = sim({'rate': rate})\n", + " temp = sim({\"rate\": rate})\n", " xt_ind = np.searchsorted(temp[\"t\"], t_test_times) - 1\n", " obs = temp[\"X\"][:, 1][xt_ind]\n", " return dict(obs=obs)" @@ -568,8 +568,8 @@ ], "source": [ "adapter = bf.approximators.ContinuousApproximator.build_adapter(\n", - " inference_variables='rate',\n", - " inference_conditions='obs',\n", + " inference_variables=\"rate\",\n", + " inference_conditions=\"obs\",\n", " summary_variables=None\n", ")\n", "adapter" @@ -665,25 +665,25 @@ "output_type": "stream", "text": [ "Epoch 1/10\n", - "\u001b[1m100/100\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 62ms/step - loss: 0.4428 - loss/inference_loss: 0.4428 - val_loss: 0.4605 - val_loss/inference_loss: 0.4605\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 62ms/step - loss: 0.4428 - loss/inference_loss: 0.4428 - val_loss: 0.4605 - val_loss/inference_loss: 0.4605\n", "Epoch 2/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 64ms/step - loss: 0.3700 - loss/inference_loss: 0.3700 - val_loss: 0.4467 - val_loss/inference_loss: 0.4467\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 64ms/step - loss: 0.3700 - loss/inference_loss: 0.3700 - val_loss: 0.4467 - val_loss/inference_loss: 0.4467\n", "Epoch 3/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 68ms/step - loss: 0.3458 - loss/inference_loss: 0.3458 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 68ms/step - loss: 0.3458 - loss/inference_loss: 0.3458 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627\n", "Epoch 4/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 70ms/step - loss: 0.3771 - loss/inference_loss: 0.3771 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 70ms/step - loss: 0.3771 - loss/inference_loss: 0.3771 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637\n", "Epoch 5/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 69ms/step - loss: 0.3729 - loss/inference_loss: 0.3729 - val_loss: 0.2138 - 
val_loss/inference_loss: 0.2138\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 69ms/step - loss: 0.3729 - loss/inference_loss: 0.3729 - val_loss: 0.2138 - val_loss/inference_loss: 0.2138\n", "Epoch 6/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 66ms/step - loss: 0.3567 - loss/inference_loss: 0.3567 - val_loss: 0.2888 - val_loss/inference_loss: 0.2888\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 66ms/step - loss: 0.3567 - loss/inference_loss: 0.3567 - val_loss: 0.2888 - val_loss/inference_loss: 0.2888\n", "Epoch 7/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 62ms/step - loss: 0.4077 - loss/inference_loss: 0.4077 - val_loss: 0.3235 - val_loss/inference_loss: 0.3235\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 62ms/step - loss: 0.4077 - loss/inference_loss: 0.4077 - val_loss: 0.3235 - val_loss/inference_loss: 0.3235\n", "Epoch 8/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 61ms/step - loss: 0.4124 - loss/inference_loss: 0.4124 - val_loss: 0.3256 - val_loss/inference_loss: 0.3256\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 61ms/step - loss: 0.4124 - loss/inference_loss: 0.4124 - val_loss: 0.3256 - val_loss/inference_loss: 0.3256\n", "Epoch 9/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 61ms/step - loss: 0.3960 - loss/inference_loss: 0.3960 - val_loss: 0.2767 - val_loss/inference_loss: 0.2767\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 61ms/step - loss: 
0.3960 - loss/inference_loss: 0.3960 - val_loss: 0.2767 - val_loss/inference_loss: 0.2767\n", "Epoch 10/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 60ms/step - loss: 0.4217 - loss/inference_loss: 0.4217 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n" + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 60ms/step - loss: 0.4217 - loss/inference_loss: 0.4217 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n" ] } ], @@ -829,7 +829,7 @@ "obs = observations[\"X\"][:, 1][xt_ind]\n", "\n", "# Obtain 1000 posterior samples\n", - "samples = workflow.sample(conditions={'obs': [obs]}, num_samples=num_samples)" + "samples = workflow.sample(conditions={\"obs\": [obs]}, num_samples=num_samples)" ] }, { @@ -881,9 +881,9 @@ "source": [ "# abc gives us weighted samples, we resample them to get comparable samples\n", "df, w = abc_history.get_distribution()\n", - "abc_samples = weighted_statistics.resample(df['rate'].values, w, 1000)\n", + "abc_samples = weighted_statistics.resample(df[\"rate\"].values, w, 1000)\n", "\n", - "f = bf.diagnostics.plots.pairs_posterior({'rate': abc_samples}, targets=np.array([true_rate]))" + "f = bf.diagnostics.plots.pairs_posterior({\"rate\": abc_samples}, targets=np.array([true_rate]))" ] }, { From b450961cb60bc6aef370f2695e74eb1514da00b6 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 25 Apr 2025 22:45:17 -0400 Subject: [PATCH 35/46] turn all single-quoted strings into double-quoted strings amend to trigger workflow --- examples/From_ABC_to_BayesFlow.ipynb | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/From_ABC_to_BayesFlow.ipynb b/examples/From_ABC_to_BayesFlow.ipynb index b9757d9c4..35947a4cd 100644 --- a/examples/From_ABC_to_BayesFlow.ipynb +++ b/examples/From_ABC_to_BayesFlow.ipynb @@ -479,7 +479,7 @@ "\n", "# For BayesFlow devs: this ensures that 
the latest dev version can be found\n", "import sys\n", - "sys.path.append('../')\n", + "sys.path.append(\"../\")\n", "\n", "import bayesflow as bf" ] @@ -513,11 +513,11 @@ "source": [ "def prior_helper():\n", " \"\"\"The ABC prior returns a Parameter Object from pyabc which we convert to a dict.\"\"\"\n", - " return dict(rate=prior.rvs()['rate'])\n", + " return dict(rate=prior.rvs()[\"rate\"])\n", "\n", "def sim_helper(rate):\n", " \"\"\"The simulator returns a dict, we extract the output at the test times.\"\"\"\n", - " temp = sim({'rate': rate})\n", + " temp = sim({\"rate\": rate})\n", " xt_ind = np.searchsorted(temp[\"t\"], t_test_times) - 1\n", " obs = temp[\"X\"][:, 1][xt_ind]\n", " return dict(obs=obs)" @@ -568,8 +568,8 @@ ], "source": [ "adapter = bf.approximators.ContinuousApproximator.build_adapter(\n", - " inference_variables='rate',\n", - " inference_conditions='obs',\n", + " inference_variables=\"rate\",\n", + " inference_conditions=\"obs\",\n", " summary_variables=None\n", ")\n", "adapter" @@ -665,25 +665,25 @@ "output_type": "stream", "text": [ "Epoch 1/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 62ms/step - loss: 0.4428 - loss/inference_loss: 0.4428 - val_loss: 0.4605 - val_loss/inference_loss: 0.4605\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 62ms/step - loss: 0.4428 - loss/inference_loss: 0.4428 - val_loss: 0.4605 - val_loss/inference_loss: 0.4605\n", "Epoch 2/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 64ms/step - loss: 0.3700 - loss/inference_loss: 0.3700 - val_loss: 0.4467 - val_loss/inference_loss: 0.4467\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 64ms/step - loss: 0.3700 - loss/inference_loss: 0.3700 - val_loss: 0.4467 - val_loss/inference_loss: 0.4467\n", "Epoch 
3/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 68ms/step - loss: 0.3458 - loss/inference_loss: 0.3458 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 68ms/step - loss: 0.3458 - loss/inference_loss: 0.3458 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627\n", "Epoch 4/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 70ms/step - loss: 0.3771 - loss/inference_loss: 0.3771 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 70ms/step - loss: 0.3771 - loss/inference_loss: 0.3771 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637\n", "Epoch 5/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 69ms/step - loss: 0.3729 - loss/inference_loss: 0.3729 - val_loss: 0.2138 - val_loss/inference_loss: 0.2138\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 69ms/step - loss: 0.3729 - loss/inference_loss: 0.3729 - val_loss: 0.2138 - val_loss/inference_loss: 0.2138\n", "Epoch 6/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 66ms/step - loss: 0.3567 - loss/inference_loss: 0.3567 - val_loss: 0.2888 - val_loss/inference_loss: 0.2888\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 66ms/step - loss: 0.3567 - loss/inference_loss: 0.3567 - val_loss: 0.2888 - val_loss/inference_loss: 0.2888\n", "Epoch 7/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 62ms/step - loss: 0.4077 - loss/inference_loss: 
0.4077 - val_loss: 0.3235 - val_loss/inference_loss: 0.3235\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 62ms/step - loss: 0.4077 - loss/inference_loss: 0.4077 - val_loss: 0.3235 - val_loss/inference_loss: 0.3235\n", "Epoch 8/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 61ms/step - loss: 0.4124 - loss/inference_loss: 0.4124 - val_loss: 0.3256 - val_loss/inference_loss: 0.3256\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 61ms/step - loss: 0.4124 - loss/inference_loss: 0.4124 - val_loss: 0.3256 - val_loss/inference_loss: 0.3256\n", "Epoch 9/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 61ms/step - loss: 0.3960 - loss/inference_loss: 0.3960 - val_loss: 0.2767 - val_loss/inference_loss: 0.2767\n", + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 61ms/step - loss: 0.3960 - loss/inference_loss: 0.3960 - val_loss: 0.2767 - val_loss/inference_loss: 0.2767\n", "Epoch 10/10\n", - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 60ms/step - loss: 0.4217 - loss/inference_loss: 0.4217 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n" + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 60ms/step - loss: 0.4217 - loss/inference_loss: 0.4217 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n" ] } ], @@ -829,7 +829,7 @@ "obs = observations[\"X\"][:, 1][xt_ind]\n", "\n", "# Obtain 1000 posterior samples\n", - "samples = workflow.sample(conditions={'obs': [obs]}, num_samples=num_samples)" + "samples = workflow.sample(conditions={\"obs\": [obs]}, num_samples=num_samples)" ] }, { @@ -881,9 +881,9 @@ "source": [ "# abc gives us 
weighted samples, we resample them to get comparable samples\n", "df, w = abc_history.get_distribution()\n", - "abc_samples = weighted_statistics.resample(df['rate'].values, w, 1000)\n", + "abc_samples = weighted_statistics.resample(df[\"rate\"].values, w, 1000)\n", "\n", - "f = bf.diagnostics.plots.pairs_posterior({'rate': abc_samples}, targets=np.array([true_rate]))" + "f = bf.diagnostics.plots.pairs_posterior({\"rate\": abc_samples}, targets=np.array([true_rate]))" ] }, { From 6fa75da583fd094bf3ec00b9d1526683f6b63310 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 13:20:49 +0000 Subject: [PATCH 36/46] [no ci] website: list example notebooks only on Examples page * one place less to update * fix one link to a notebook --- docsrc/source/index.md | 10 +--------- examples/SIR_Posterior_Estimation.ipynb | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/docsrc/source/index.md b/docsrc/source/index.md index f89c5ff3f..f0cc6c159 100644 --- a/docsrc/source/index.md +++ b/docsrc/source/index.md @@ -45,15 +45,7 @@ history = workflow.fit_online(epochs=50, batch_size=32, num_batches_per_epoch=50 diagnostics = workflow.plot_default_diagnostics(test_data=300) ``` -For an in-depth exposition, check out our walkthrough notebooks below. - -1. [Linear regression starter example](_examples/Linear_Regression_Starter.ipynb) -2. [From ABC to BayesFlow](_examples/From_ABC_to_BayesFlow.ipynb) -3. [Two moons starter example](_examples/Two_Moons_Starter.ipynb) -4. [Rapid iteration with point estimators](_examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb) -5. [SIR model with custom summary network](_examples/SIR_Posterior_Estimation.ipynb) -6. [Bayesian experimental design](_examples/Bayesian_Experimental_Design.ipynb) -7. [Simple model comparison example](_examples/One_Sample_TTest.ipynb) +For an in-depth exposition, check out our walkthrough notebooks in the {doc}`Examples <../examples>` section. More tutorials are always welcome! 
Please consider making a pull request if you have a cool application that you want to contribute. diff --git a/examples/SIR_Posterior_Estimation.ipynb b/examples/SIR_Posterior_Estimation.ipynb index 7963d00e5..2e3f94bd9 100644 --- a/examples/SIR_Posterior_Estimation.ipynb +++ b/examples/SIR_Posterior_Estimation.ipynb @@ -84,7 +84,7 @@ "id": "39846c15b88eaf8e", "metadata": {}, "source": [ - "As described in our [very first notebook](linear_Regression_Starter.ipynb), a generative model consists of a prior (encoding suitable parameter ranges) and a simulator (generating data given simulations). Our underlying model distinguishes between susceptible, $S$, infected, $I$, and recovered, $R$, individuals with infection and recovery occurring at a constant transmission rate $\\lambda$ and constant recovery rate $\\mu$, respectively. The model dynamics are governed by the following system of ODEs:\n", + "As described in our [very first notebook](Linear_Regression_Starter.ipynb), a generative model consists of a prior (encoding suitable parameter ranges) and a simulator (generating data given simulations). Our underlying model distinguishes between susceptible, $S$, infected, $I$, and recovered, $R$, individuals with infection and recovery occurring at a constant transmission rate $\\lambda$ and constant recovery rate $\\mu$, respectively. 
The model dynamics are governed by the following system of ODEs:\n", "\n", "$$\n", "\\begin{align}\n", From 5b5363cfbdeef73f0dc7d9e2d8001144031444b4 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 27 Apr 2025 13:44:43 +0000 Subject: [PATCH 37/46] ci: update pip via python -m pip pip install -U pip setuptools wheel leads to an error: https://github.com/bayesflow-org/bayesflow/actions/runs/14692655483/job/41230057180?pr=449 --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ab3d03078..eee02895c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -43,7 +43,7 @@ jobs: - name: Install Dependencies run: | - pip install -U pip setuptools wheel + python -m pip install -U pip setuptools wheel pip install .[test] - name: Install JAX From a322ff176af75ba9ba6c763a009a42095075c10f Mon Sep 17 00:00:00 2001 From: Simon Kucharsky Date: Tue, 29 Apr 2025 20:18:18 +0200 Subject: [PATCH 38/46] Adapter keeps track of the transform jacobians (#419) * minimal working case (.scale) * concatenate * keep, drop, rename * scale, log, sqrt * standardize * constraint transforms * continuous approximator returns log_prob with volume correction * loop for inverse jacobian * inverse for elementwise * inverse for Transforms * raise error with numpy transform (for now) * do not fail if no transform is used * take care of log1p as well * fix filter transforms, boundary condition * add tests for adapter jacobians * document jacobian arg * jacobian -> log_det_jac * add test for inverse concatenation * fix standardize * correct nesting in map_transform --- bayesflow/adapters/adapter.py | 45 ++++++++++---- bayesflow/adapters/transforms/concatenate.py | 34 +++++++++++ bayesflow/adapters/transforms/constrain.py | 29 +++++++++ bayesflow/adapters/transforms/drop.py | 3 + .../transforms/elementwise_transform.py | 3 + .../adapters/transforms/filter_transform.py 
| 30 ++++++++- bayesflow/adapters/transforms/keep.py | 3 + bayesflow/adapters/transforms/log.py | 9 +++ .../adapters/transforms/map_transform.py | 45 ++++++++++---- .../adapters/transforms/numpy_transform.py | 3 + bayesflow/adapters/transforms/rename.py | 3 + bayesflow/adapters/transforms/scale.py | 7 +++ bayesflow/adapters/transforms/sqrt.py | 6 ++ bayesflow/adapters/transforms/standardize.py | 7 +++ bayesflow/adapters/transforms/transform.py | 5 ++ .../approximators/continuous_approximator.py | 7 ++- tests/test_adapters/conftest.py | 40 ++++++++++++ tests/test_adapters/test_adapters.py | 61 ++++++++++++++++++- 18 files changed, 312 insertions(+), 28 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 4db738eef..ab6800d8a 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -79,7 +79,9 @@ def get_config(self) -> dict: return serialize(config) - def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]: + def forward( + self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the forward direction. Parameters @@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) - The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. + log_det_jac: bool, optional + Whether to return the log determinant of the Jacobian of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and log determinant of the Jacobian. 
""" data = data.copy() + if not log_det_jac: + for transform in self.transforms: + data = transform(data, stage=stage, **kwargs) + return data + log_det_jac = {} for transform in self.transforms: - data = transform(data, stage=stage, **kwargs) + transformed_data = transform(data, stage=stage, **kwargs) + log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs) + data = transformed_data - return data + return data, log_det_jac - def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]: + def inverse( + self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the inverse direction. Parameters @@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. + log_det_jac: bool, optional + Whether to return the log determinant of the Jacobian of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and log determinant of the Jacobian. 
""" data = data.copy() + if not log_det_jac: + for transform in reversed(self.transforms): + data = transform(data, stage=stage, inverse=True, **kwargs) + return data + log_det_jac = {} for transform in reversed(self.transforms): data = transform(data, stage=stage, inverse=True, **kwargs) + log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs) - return data + return data, log_det_jac def __call__( self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs - ) -> dict[str, np.ndarray]: + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the given direction. Parameters @@ -145,8 +166,8 @@ def __call__( Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and log determinant of the Jacobian. """ if inverse: return self.inverse(data, stage=stage, **kwargs) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index deb54fc3f..91ea9178b 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -115,3 +115,37 @@ def extra_repr(self) -> str: result += f", axis={self.axis}" return result + + def log_det_jac( + self, + data: dict[str, np.ndarray], + log_det_jac: dict[str, np.ndarray], + *, + strict: bool = False, + inverse: bool = False, + **kwargs, + ) -> dict[str, np.ndarray]: + # copy to avoid side effects + log_det_jac = log_det_jac.copy() + + if inverse: + if log_det_jac.get(self.into) is not None: + raise ValueError( + "Cannot obtain an inverse Jacobian of concatenation. " + "Transform your variables before you concatenate." 
+ ) + + return log_det_jac + + required_keys = set(self.keys) + available_keys = set(log_det_jac.keys()) + common_keys = available_keys & required_keys + + if len(common_keys) == 0: + return log_det_jac + + parts = [log_det_jac.pop(key) for key in common_keys] + + log_det_jac[self.into] = sum(parts) + + return log_det_jac diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py index 5f93135a1..a4ca0be25 100644 --- a/bayesflow/adapters/transforms/constrain.py +++ b/bayesflow/adapters/transforms/constrain.py @@ -87,6 +87,11 @@ def constrain(x): def unconstrain(x): return inverse_sigmoid((x - lower) / (upper - lower)) + + def ldj(x): + x = (x - lower) / (upper - lower) + return -np.log(x) - np.log1p(-x) - np.log(upper - lower) + case str() as name: raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.") case other: @@ -101,6 +106,11 @@ def constrain(x): def unconstrain(x): return inverse_softplus(x - lower) + + def ldj(x): + x = x - lower + return x - np.log(np.exp(x) - 1) + case "exp" | "log": def constrain(x): @@ -108,6 +118,10 @@ def constrain(x): def unconstrain(x): return np.log(x - lower) + + def ldj(x): + return -np.log(x - lower) + case str() as name: raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") case other: @@ -122,6 +136,11 @@ def constrain(x): def unconstrain(x): return -inverse_softplus(-(x - upper)) + + def ldj(x): + x = -(x - upper) + return x - np.log(np.exp(x) - 1) + case "exp" | "log": def constrain(x): @@ -129,6 +148,9 @@ def constrain(x): def unconstrain(x): return -np.log(-x + upper) + + def ldj(x): + return -np.log(-x + upper) case str() as name: raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") case other: @@ -142,6 +164,7 @@ def unconstrain(x): self.constrain = constrain self.unconstrain = unconstrain + self.ldj = ldj # do this last to avoid serialization issues match inclusive: @@ -178,3 +201,9 
@@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: # inverse means network space -> data space, so constrain the data return self.constrain(data) + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + ldj = self.ldj(data) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py index 51615d632..91dcd6a28 100644 --- a/bayesflow/adapters/transforms/drop.py +++ b/bayesflow/adapters/transforms/drop.py @@ -46,3 +46,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py index 3bde5a1da..7d603d517 100644 --- a/bayesflow/adapters/transforms/elementwise_transform.py +++ b/bayesflow/adapters/transforms/elementwise_transform.py @@ -25,3 +25,6 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: raise NotImplementedError + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None: + return None diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index e1920e73c..7eccf370b 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -150,9 +150,35 @@ def _should_transform(self, key: str, value: np.ndarray, inverse: bool = False) return predicate(key, value, inverse=inverse) def _apply_transform(self, key: str, 
value: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + transform = self._get_transform(key) + + return transform(value, inverse=inverse, **kwargs) + + def _get_transform(self, key: str) -> ElementwiseTransform: if key not in self.transform_map: self.transform_map[key] = self.transform_constructor(**self.kwargs) - transform = self.transform_map[key] + return self.transform_map[key] - return transform(value, inverse=inverse, **kwargs) + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs + ): + data = data.copy() + + if strict and self.include is not None: + missing_keys = set(self.include) - set(data.keys()) + if missing_keys: + raise KeyError(f"Missing keys from include list: {missing_keys!r}") + + for key, value in data.items(): + if self._should_transform(key, value, inverse=False): + transform = self._get_transform(key) + ldj = transform.log_det_jac(value, **kwargs) + if ldj is None: + continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj + + return log_det_jac diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py index 62373071f..56f395166 100644 --- a/bayesflow/adapters/transforms/keep.py +++ b/bayesflow/adapters/transforms/keep.py @@ -57,3 +57,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index 3184ab979..d5f559b4f 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -37,3 +37,12 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> 
dict: return serialize({"p1": self.p1}) + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + if self.p1: + ldj = -np.log1p(data) + else: + ldj = -np.log(data) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py index 7820ce611..5da8292af 100644 --- a/bayesflow/adapters/transforms/map_transform.py +++ b/bayesflow/adapters/transforms/map_transform.py @@ -41,12 +41,8 @@ def get_config(self) -> dict: def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) -> dict[str, np.ndarray]: data = data.copy() - required_keys = set(self.transform_map.keys()) - available_keys = set(data.keys()) - missing_keys = required_keys - available_keys - - if strict and missing_keys: - raise KeyError(f"Missing keys: {missing_keys!r}") + if strict: + self._check_keys(data) for key, transform in self.transform_map.items(): if key in data: @@ -57,15 +53,40 @@ def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) def inverse(self, data: dict[str, np.ndarray], *, strict: bool = False, **kwargs) -> dict[str, np.ndarray]: data = data.copy() - required_keys = set(self.transform_map.keys()) - available_keys = set(data.keys()) - missing_keys = required_keys - available_keys - - if strict and missing_keys: - raise KeyError(f"Missing keys: {missing_keys!r}") + if strict: + self._check_keys(data) for key, transform in self.transform_map.items(): if key in data: data[key] = transform.inverse(data[key], **kwargs) return data + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs + ) -> dict[str, np.ndarray]: + data = data.copy() + + if strict: + self._check_keys(data) + + for key, transform in self.transform_map.items(): + if key in data: + ldj = transform.log_det_jac(data[key], **kwargs) + + if ldj is None: + 
continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj + + return log_det_jac + + def _check_keys(self, data: dict[str, np.ndarray]): + required_keys = set(self.transform_map.keys()) + available_keys = set(data.keys()) + missing_keys = required_keys - available_keys + + if missing_keys: + raise KeyError(f"Missing keys: {missing_keys!r}") diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index aecf03bba..29d25dc67 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -72,3 +72,6 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return self._inverse(data) + + def log_det_jac(self, data, inverse=False, **kwargs): + raise NotImplementedError("log determinant of the Jacobian of the numpy transforms are not implemented yet") diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 49cc52eba..746ef5a80 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -58,3 +58,6 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di def extra_repr(self) -> str: return f"{self.from_key!r} -> {self.to_key!r}" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False) diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 8d9dce1be..96b2ff927 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -18,3 +18,10 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data / self.scale + + def log_det_jac(self, 
data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + ldj = np.log(np.abs(self.scale)) + ldj = np.full(data.shape, ldj) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 617f892bc..4ef1370dc 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -22,3 +22,9 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return {} + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + ldj = -0.5 * np.log(data) - np.log(2) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 52899917c..9699819b9 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -120,3 +120,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: std = np.broadcast_to(self.std, data.shape) return data * std + mean + + def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray: + std = np.broadcast_to(self.std, data.shape) + ldj = np.log(np.abs(std)) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py index 4642c1165..ed3058e15 100644 --- a/bayesflow/adapters/transforms/transform.py +++ b/bayesflow/adapters/transforms/transform.py @@ -35,3 +35,8 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray def extra_repr(self) -> str: return "" + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], inverse: bool = False, **kwargs + ) -> dict[str, np.ndarray]: + return log_det_jac diff --git a/bayesflow/approximators/continuous_approximator.py 
b/bayesflow/approximators/continuous_approximator.py index dcb661ca0..bf4e263a0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -417,11 +417,16 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic np.ndarray Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))` """ - data = self.adapter(data, strict=False, stage="inference", **kwargs) + data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs) data = keras.tree.map_structure(keras.ops.convert_to_tensor, data) log_prob = self._log_prob(**data, **kwargs) log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob) + # change of variables formula + log_det_jac = log_det_jac.get("inference_variables") + if log_det_jac is not None: + log_prob = log_prob + log_det_jac + return log_prob def _log_prob( diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 873279f09..d69cd4be4 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -49,6 +49,8 @@ def random_data(): "z1": np.random.standard_normal(size=(32, 2)), "p1": np.random.lognormal(size=(32, 2)), "p2": np.random.lognormal(size=(32, 2)), + "p3": np.random.lognormal(size=(32, 2)), + "n1": 1 - np.random.lognormal(size=(32, 2)), "s1": np.random.standard_normal(size=(32, 3, 2)), "s2": np.random.standard_normal(size=(32, 3, 2)), "t1": np.zeros((3, 2)), @@ -56,5 +58,43 @@ def random_data(): "d1": np.random.standard_normal(size=(32, 2)), "d2": np.random.standard_normal(size=(32, 2)), "o1": np.random.randint(0, 9, size=(32, 2)), + "u1": np.random.uniform(low=-1, high=2, size=(32, 1)), "key_to_split": np.random.standard_normal(size=(32, 10)), } + + +@pytest.fixture() +def adapter_log_det_jac(): + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .scale("x1", by=2) + .log("p1", p1=True) + 
.sqrt("p2") + .constrain("p3", lower=0) + .constrain("n1", upper=1) + .constrain("u1", lower=-1, upper=2) + .concatenate(["p1", "p2", "p3"], into="p") + .rename("u1", "u") + ) + + return adapter + + +@pytest.fixture() +def adapter_log_det_jac_inverse(): + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .standardize("x1", mean=1, std=2) + .log("p1") + .sqrt("p2") + .constrain("p3", lower=0, method="log") + .constrain("n1", upper=1, method="log") + .constrain("u1", lower=-1, upper=2) + .scale(["p1", "p2", "p3"], by=3.5) + ) + + return adapter diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index d6215170e..1784befb7 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -13,7 +13,7 @@ def test_cycle_consistency(adapter, random_data): deprocessed = adapter(processed, inverse=True) for key, value in random_data.items(): - if key in ["d1", "d2"]: + if key in ["d1", "d2", "p3", "n1", "u1"]: # dropped continue assert key in deprocessed @@ -230,3 +230,62 @@ def test_to_dict_transform(): # category should have 5 one-hot categories, even though it was only passed 4 assert processed["category"].shape[-1] == 5 + + +def test_log_det_jac(adapter_log_det_jac, random_data): + d, log_det_jac = adapter_log_det_jac(random_data, log_det_jac=True) + + assert np.allclose(log_det_jac["x1"], np.log(2)) + + p1 = -np.log1p(random_data["p1"]) + p2 = -0.5 * np.log(random_data["p2"]) - np.log(2) + p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1) + p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1) + + assert np.allclose(log_det_jac["p"], p) + + n1 = -(random_data["n1"] - 1) + n1 = n1 - np.log(np.exp(n1) - 1) + n1 = np.sum(n1, axis=-1) + + assert np.allclose(log_det_jac["n1"], n1) + + u1 = random_data["u1"] + u1 = (u1 + 1) / 3 + u1 = -np.log(u1) - np.log1p(-u1) - np.log(3) + + assert np.allclose(log_det_jac["u"], u1[:, 0]) + + +def 
test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data): + d, forward_log_det_jac = adapter_log_det_jac_inverse(random_data, log_det_jac=True) + d, inverse_log_det_jac = adapter_log_det_jac_inverse(d, inverse=True, log_det_jac=True) + + for key in forward_log_det_jac.keys(): + assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key]) + + +def test_log_det_jac_exceptions(random_data): + # Test cannot compute inverse log_det_jac + # e.g., when we apply a concat and then a transform that + # we cannot "unconcatenate" the log_det_jac + # (because the log_det_jac are summed, not concatenated) + adapter = bf.Adapter().concatenate(["p1", "p2", "p3"], into="p").sqrt("p") + transformed_data, log_det_jac = adapter(random_data, log_det_jac=True) + + # test that inverse raises error + with pytest.raises(ValueError): + adapter(transformed_data, inverse=True, log_det_jac=True) + + # test resolvable order: first transform, then concatenate + adapter = bf.Adapter().sqrt(["p1", "p2", "p3"]).concatenate(["p1", "p2", "p3"], into="p") + + transformed_data, forward_log_det_jac = adapter(random_data, log_det_jac=True) + data, inverse_log_det_jac = adapter(transformed_data, inverse=True, log_det_jac=True) + inverse_log_det_jac = sum(inverse_log_det_jac.values()) + + # forward is the same regardless + assert np.allclose(forward_log_det_jac["p"], log_det_jac["p"]) + + # inverse works when concatenation is used after transforms + assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac) From 5595eabaa08f2e6bb036106e17164e8af9b47e8e Mon Sep 17 00:00:00 2001 From: Valentin Pratz <112951103+vpratz@users.noreply.github.com> Date: Tue, 29 Apr 2025 20:22:15 +0200 Subject: [PATCH 39/46] Deprecate old serialization (`(de)serialize_value_or_type`) and add developer docs (#450) * document and expose serialization module * add functools.wraps call to allow_kwargs decorator, as before it was breaking the autodoc functionality * restructure and update developer docs * 
move content to separate pages * update section on serialization * ci: update pip via python -m pip pip install -U pip setuptools wheel leads to an error: https://github.com/bayesflow-org/bayesflow/actions/runs/14692655483/job/41230057180?pr=449 * serializable: increase depth in sys._getframe The functools.wrap decorator adds a frame object to the call stack * deprecate (de)serialize_value_or_type - add deprecation warning, remove functionality - replace all occurences with the corresponding new functions --- bayesflow/networks/point_inference_network.py | 14 +- bayesflow/scores/scoring_rule.py | 19 +-- bayesflow/utils/__init__.py | 3 +- bayesflow/utils/decorators.py | 1 + bayesflow/utils/serialization.py | 152 ++++++++++-------- docsrc/source/development/index.md | 90 ++--------- docsrc/source/development/introduction.md | 12 ++ docsrc/source/development/pitfalls.md | 13 ++ docsrc/source/development/serialization.md | 26 +++ docsrc/source/development/stages.md | 8 + 10 files changed, 169 insertions(+), 169 deletions(-) create mode 100644 docsrc/source/development/introduction.md create mode 100644 docsrc/source/development/pitfalls.md create mode 100644 docsrc/source/development/serialization.md create mode 100644 docsrc/source/development/stages.md diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 3b1699e5a..63094a2a8 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -1,11 +1,7 @@ import keras -from keras.saving import ( - deserialize_keras_object as deserialize, - serialize_keras_object as serialize, - register_keras_serializable as serializable, -) -from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type +from bayesflow.utils import model_kwargs, find_network +from bayesflow.utils.serialization import deserialize, serializable, serialize from bayesflow.types import Shape, Tensor from 
bayesflow.scores import ScoringRule, ParametricDistributionScore from bayesflow.utils.decorators import allow_batch_size @@ -30,10 +26,10 @@ def __init__( self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {})) self.config = { + "subnet": serialize(subnet), + "scores": serialize(scores), **kwargs, } - self.config = serialize_value_or_type(self.config, "subnet", subnet) - self.config["scores"] = serialize(self.scores) def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None: """Builds all network components based on shapes of conditions and targets. @@ -119,7 +115,7 @@ def get_config(self): def from_config(cls, config): config = config.copy() config["scores"] = deserialize(config["scores"]) - config = deserialize_value_or_type(config, "subnet") + config["subnet"] = deserialize(config["subnet"]) return cls(**config) def call( diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index a1a3f5717..0144de458 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -1,10 +1,10 @@ import math import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor -from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type +from bayesflow.utils import find_network +from bayesflow.utils.serialization import deserialize, serializable, serialize @serializable(package="bayesflow.scores") @@ -51,23 +51,16 @@ def __init__( self.config = {"subnets_kwargs": self.subnets_kwargs} def get_config(self): - self.config["subnets"] = { - key: serialize_value_or_type({}, "subnet", subnet) for key, subnet in self.subnets.items() - } - self.config["links"] = {key: serialize_value_or_type({}, "link", link) for key, link in self.links.items()} + self.config["subnets"] = {key: serialize(subnet) for key, subnet in self.subnets.items()} + self.config["links"] = {key: serialize(link) for key, link in self.links.items()} 
return self.config @classmethod def from_config(cls, config): config = config.copy() - config["subnets"] = { - key: deserialize_value_or_type(subnet_dict, "subnet")["subnet"] - for key, subnet_dict in config["subnets"].items() - } - config["links"] = { - key: deserialize_value_or_type(link_dict, "link")["link"] for key, link_dict in config["links"].items() - } + config["subnets"] = {key: deserialize(subnet_dict) for key, subnet_dict in config["subnets"].items()} + config["links"] = {key: deserialize(link_dict) for key, link_dict in config["links"].items()} return cls(**config) diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 737c533ce..47ab771ff 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -6,6 +6,7 @@ keras_utils, logging, numpy_utils, + serialization, ) from .callbacks import detailed_loss_callback @@ -104,4 +105,4 @@ from ._docs import _add_imports_to_all -_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils"]) +_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"]) diff --git a/bayesflow/utils/decorators.py b/bayesflow/utils/decorators.py index 7fd32edc9..1283fe66a 100644 --- a/bayesflow/utils/decorators.py +++ b/bayesflow/utils/decorators.py @@ -17,6 +17,7 @@ def allow_args(fn: Decorator) -> Decorator: def wrapper(f: Fn) -> Fn: ... @overload def wrapper(*fargs: any, **fkwargs: any) -> Fn: ... 
+ @wraps(fn) def wrapper(*fargs: any, **fkwargs: any) -> Fn: if len(fargs) == 1 and not fkwargs and callable(fargs[0]): # called without arguments diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index 500264f05..bb55aee41 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -5,6 +5,7 @@ import keras import numpy as np import sys +from warnings import warn # this import needs to be exactly like this to work with monkey patching from keras.saving import deserialize_keras_object @@ -19,109 +20,102 @@ def serialize_value_or_type(config, name, obj): - """Serialize an object that can be either a value or a type - and add it to a copy of the supplied dictionary. + """This function is deprecated.""" + warn( + "This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize.", + DeprecationWarning, + stacklevel=2, + ) - Parameters - ---------- - config : dict - Dictionary to add the serialized object to. This function does not - modify the dictionary in place, but returns a modified copy. - name : str - Name of the obj that should be stored. Required for later deserialization. - obj : object or type - The object to serialize. If `obj` is of type `type`, we use - `keras.saving.get_registered_name` to obtain the registered type name. - If it is not a type, we try to serialize it as a Keras object. - Returns - ------- - updated_config : dict - Updated dictionary with a new key `"_bayesflow__type"` or - `"_bayesflow__val"`. The prefix is used to avoid name collisions, - the suffix indicates how the stored value has to be deserialized. - - Notes - ----- - We allow strings or `type` parameters at several places to instantiate objects - of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot - be serialized, we have to distinguish the two cases for serialization and - deserialization. This function is a helper function to standardize and - simplify this. 
- """ - updated_config = config.copy() - if isinstance(obj, type): - updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj) - else: - updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj) - return updated_config +def deserialize_value_or_type(config, name): + """This function is deprecated.""" + warn( + "This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize.", + DeprecationWarning, + stacklevel=2, + ) -def deserialize_value_or_type(config, name): - """Deserialize an object that can be either a value or a type and add - it to the supplied dictionary. +def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): + """Deserialize an object serialized with :py:func:`serialize`. + + Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of + classes. Parameters ---------- - config : dict - Dictionary containing the object to deserialize. If a type was - serialized, it should contain the key `"_bayesflow__type"`. - If an object was serialized, it should contain the key - `"_bayesflow__val"`. In a copy of this dictionary, - the item will be replaced with the key `name`. - name : str - Name of the object to deserialize. + config : dict + Python dict describing the object. + custom_objects : dict, optional + Python dict containing a mapping between custom object names and the corresponding + classes or functions. Forwarded to `keras.saving.deserialize_keras_object`. + safe_mode : bool, optional + Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False, + loading an object has the potential to trigger arbitrary code execution. This argument + is only applicable to the Keras v3 model format. Defaults to True. + Forwarded to `keras.saving.deserialize_keras_object`. Returns ------- - updated_config : dict - Updated dictionary with a new key `name`, with a value that is either - a type or an object. 
+ obj : + The object described by the config dictionary. + + Raises + ------ + ValueError + If a type in the config can not be deserialized. See Also -------- - serialize_value_or_type + serialize """ - updated_config = config.copy() - if f"{PREFIX}{name}_type" in config: - updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"]) - del updated_config[f"{PREFIX}{name}_type"] - elif f"{PREFIX}{name}_val" in config: - updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"]) - del updated_config[f"{PREFIX}{name}_val"] - return updated_config - - -def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs): with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize: - if isinstance(obj, str) and obj.startswith(_type_prefix): + if isinstance(config, str) and config.startswith(_type_prefix): # we marked this as a type during serialization - obj = obj[len(_type_prefix) :] + config = config[len(_type_prefix) :] tp = keras.saving.get_registered_object( # TODO: can we pass module objects without overwriting numpy's dict with builtins? - obj, + config, custom_objects=custom_objects, module_objects=np.__dict__ | builtins.__dict__, ) if tp is None: raise ValueError( - f"Could not deserialize type {obj!r}. Make sure it is registered with " + f"Could not deserialize type {config!r}. Make sure it is registered with " f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`." 
 ) return tp - if inspect.isclass(obj): + if inspect.isclass(config): # add this base case since keras does not cover it - return obj + return config - obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs) + obj = original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs) return obj @allow_args -def serializable(cls, package=None, name=None): +def serializable(cls, package: str | None = None, name: str | None = None): + """Register class as Keras serializable. + + Wrapper function around `keras.saving.register_keras_serializable` to automatically + set the `package` and `name` arguments. + + Parameters + ---------- + cls : type + The class to register. + package : str, optional + `package` argument forwarded to `keras.saving.register_keras_serializable`. + If None is provided, the package is automatically inferred using the __name__ + attribute of the module the class resides in. + name : str, optional + `name` argument forwarded to `keras.saving.register_keras_serializable`. + If None is provided, the class's __name__ attribute is used. + """ if package is None: - frame = sys._getframe(1) + frame = sys._getframe(2) g = frame.f_globals package = g.get("__name__", "bayesflow") @@ -133,6 +127,26 @@ def serializable(cls, package=None, name=None): def serialize(obj): + """Serialize an object using Keras. + + Wrapper function around `keras.saving.serialize_keras_object`, which adds the + ability to serialize classes. + + Parameters + ---------- + obj : Keras serializable object, or class + The object to serialize + + Returns + ------- + config : dict + A python dict that represents the object. The python dict can be deserialized via + :py:func:`deserialize`. 
+ + See Also + -------- + deserialize + """ if isinstance(obj, (tuple, list, dict)): return keras.tree.map_structure(serialize, obj) elif inspect.isclass(obj): diff --git a/docsrc/source/development/index.md b/docsrc/source/development/index.md index c62971532..adbadf21f 100644 --- a/docsrc/source/development/index.md +++ b/docsrc/source/development/index.md @@ -1,87 +1,23 @@ -# Patterns & Caveats +# Developer Documentation -**Note**: This document is part of BayesFlow's developer documentation, and +**Attention:** You are looking at BayesFlow's developer documentation, which is aimed at people who want to extend or improve BayesFlow. For user documentation, -please refer to the examples and the public API documentation. +please refer to the {doc}`../examples` and the {doc}`../api/bayesflow`. -## Introduction - -From version 2 on, BayesFlow is built on [Keras](https://keras.io/) v3, which -allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch. -By using functionality provided by Keras, and extending it with backend-specific -code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as -well. - -As Keras is built upon three different backend, each with different functionality -and design decisions, it has its own quirks and compromises. This documents -outlines some of them, along with the design decisions and programming patterns -we use to counter them. - -This document is work in progress, so if you read through the code base and +This section is work in progress, so if you read through the code base and encounter something that looks odd, but shows up in multiple places, please open an issue so that we can add it here. Also, if you introduce a new pattern that others will have to use in the future as well, please document it here, along with some background information on why it is necessary and how to use it in practice. 
-## Privileged `training` argument in the `call()` method cannot be passed via `kwargs` - -For layers that have different behavior at training and inference time (e.g., -dropout or batch normalization layers), a boolean `training` argument can be -exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method). -If we want to pass this manually, we have to do so explicitly and not as part -of a set of keyword arguments via `**kwargs`. - -@Lars: Maybe you can add more details on what is going on behind the scenes. - -## Serialization - -Serialization deals with the problem of storing objects to disk, and loading -them at a later point in time. This is straight-forward for data structures like -numpy arrays, but for classes with custom behavior, like approximators or neural -network layers, it is somewhat more complex. - -Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/) -for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/) -for advanced concepts. - -The basic idea is: by storing the arguments of the constructor of a class -(i.e., the arguments of the `__init__` function), we can later construct an -object identical to the one we have stored, except for the weights. -As the structure is identical, we can then map the stored weights to the newly -constructed object. The caveat is that all arguments have to be either basic -Python objects (like int, float, string, bool, ...) or themselves serializable. -If they are not, we have to manually specify how to serialize them, and how to -load them later on. 
- -### Registering classes as serializable - -TODO - -### Serialization of custom types - -In BayesFlow, we often encounter situations where we do not want to pass a -specific object (e.g., an MPL of a certain size), but we want to pass its type -(MLP) and the arguments to construct it. With the type and the arguments, we can -then construct multiple instances of the network in different places, for example -as the network inside a coupling block. - -Unfortunately, `type` is not Keras serializable, so we have to serialize those -arguments manually. To complicate matters further, we also allow passing a string -instead of a type, which is then used to select the correct type. - -To make it more concrete, we look at the `CouplingFlow` class, which takes the -argument `subnet` that provide the type of the subnet. It is either a -string (e.g., `"mlp"`) or a class (e.g., `bayesflow.networks.MLP`). In the first -case, we can just store the value and load it, in the latter case, we first have -to convert the type to a string that we can later convert back into a type. - -We provide two helper functions that can deal with both cases: -`bayesflow.utils.serialize_value_or_type(config, name, obj)` and -`bayesflow.utils.deserialize_value_or_type(config, name)`. -In `get_config`, we use the first to store the object, whereas we use the -latter in `from_config` to load it again. +```{toctree} +:maxdepth: 1 +:titlesonly: +:numbered: -As we need all arguments to `__init__` in `get_config`, it can make sense to -build a `config` dictionary in `__init__` already, which can then be stored when -`get_config` is called. Take a look at `CouplingFlow` for an example of that. 
+introduction +pitfalls +stages +serialization +``` diff --git a/docsrc/source/development/introduction.md b/docsrc/source/development/introduction.md new file mode 100644 index 000000000..a60830c2a --- /dev/null +++ b/docsrc/source/development/introduction.md @@ -0,0 +1,12 @@ +# Introduction + +From version 2 on, BayesFlow is built on [Keras3](https://keras.io/), which +allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch. +By using functionality provided by Keras, and extending it with backend-specific +code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as +well. + +As Keras is built upon three different backends, each with different functionality +and design decisions, it comes with its own quirks and compromises. The following documents +outline some of them, along with the design decisions and programming patterns +we use to counter them. diff --git a/docsrc/source/development/pitfalls.md b/docsrc/source/development/pitfalls.md new file mode 100644 index 000000000..69d183ec1 --- /dev/null +++ b/docsrc/source/development/pitfalls.md @@ -0,0 +1,13 @@ +# Potential Pitfalls + +This document covers things we have learned during development that might cause problems or hard to find bugs. + +## Privileged `training` argument in the `call()` method cannot be passed via `kwargs` + +For layers that have different behavior at training and inference time (e.g., +dropout or batch normalization layers), a boolean `training` argument can be +exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method). +If we want to pass this manually, we have to do so explicitly and not as part +of a set of keyword arguments via `**kwargs`. + +@Lars: Maybe you can add more details on what is going on behind the scenes. 
diff --git a/docsrc/source/development/serialization.md b/docsrc/source/development/serialization.md new file mode 100644 index 000000000..15c812686 --- /dev/null +++ b/docsrc/source/development/serialization.md @@ -0,0 +1,26 @@ +# Serialization: Enable Model Saving & Loading + +Serialization deals with the problem of storing objects to disk, and loading them at a later point in time. +This is straightforward for data structures like numpy arrays, but for classes with custom behavior it is somewhat more complex. + +Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/) for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/) for advanced concepts. + +The basic idea is: by storing the arguments of the constructor of a class (i.e., the arguments of the `__init__` function), we can later construct an object similar to the one we have stored, except for the weights and other stateful content. +As the structure is identical, we can then map the stored weights to the newly constructed object. +The caveat is that all arguments have to be either basic Python objects (like int, float, string, bool, ...) or themselves serializable. +If they are not, we have to manually specify how to serialize them, and how to load them later on. +One important example is that types are not serializable. +As we want/need to pass them in some places, we have to resort to some custom behavior, which is described below. + +## Serialization Utilities + +BayesFlow's serialization utilities can be found in the {py:mod}`~bayesflow.utils.serialization` module. +We mainly provide three convenience functions: + +- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to provide automatic `package` and `name` arguments. 
+- The {py:func}`~bayesflow.utils.serialization.serialize` function adds support for serializing classes. +- Its counterpart, {py:func}`~bayesflow.utils.serialization.deserialize`, adds support for deserializing classes. + +## Usage + +To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` methods. Please refer to existing classes in the library for usage examples. diff --git a/docsrc/source/development/stages.md b/docsrc/source/development/stages.md new file mode 100644 index 000000000..e9aa4ad8f --- /dev/null +++ b/docsrc/source/development/stages.md @@ -0,0 +1,8 @@ +# Stages + +To keep track of the phase each functionality is called in, we provide a `stage` parameter. +There are three stages: + +- `training`: The stage to train the approximator (and related stateful objects, like the adapter) +- `validation`: Identical setting to `training`, but calls in this stage should _not_ change the approximator +- `inference`: Calls in this stage should not change the approximator. In addition, the input structure might be different compared to the training phase. For example, for sampling, we only provide `summary_conditions` and `inference_conditions`, but not the `inference_variables`, which we want to infer. 
From 62675c33f51f8c6f531046be3e767ad07b1890fb Mon Sep 17 00:00:00 2001 From: The-Gia Leo Nguyen Date: Fri, 2 May 2025 09:14:15 +0200 Subject: [PATCH 40/46] Implement Feature for Issue #379: MMD Hypothesis Test (#384) - add `bootstrap_comparison` and `summary_space_comparison` to enable comparisons of two domains in the data space or the summary space via bootstrapping - add `.summaries()` function for easy access to summaries to `ContinuousApproximator` and `ModelComparisonApproximator` - add tests for the added functionality --------- Co-authored-by: Valentin Pratz --- .../approximators/continuous_approximator.py | 33 ++ .../model_comparison_approximator.py | 33 ++ bayesflow/diagnostics/__init__.py | 9 +- bayesflow/diagnostics/metrics/__init__.py | 1 + .../metrics/model_misspecification.py | 155 ++++++++++ tests/test_approximators/conftest.py | 31 ++ tests/test_approximators/test_summaries.py | 23 ++ tests/test_diagnostics/conftest.py | 14 + .../test_diagnostics_metrics.py | 290 +++++++++++++++++- tests/utils/__init__.py | 1 + tests/utils/networks.py | 8 + 11 files changed, 595 insertions(+), 3 deletions(-) create mode 100644 bayesflow/diagnostics/metrics/model_misspecification.py create mode 100644 tests/test_approximators/test_summaries.py create mode 100644 tests/utils/networks.py diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index bf4e263a0..834521d4b 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -400,6 +400,39 @@ def _sample( **filter_kwargs(kwargs, self.inference_network.sample), ) + def summaries(self, data: Mapping[str, np.ndarray], **kwargs): + """ + Computes the summaries of given data. + + The `data` dictionary is preprocessed using the `adapter` and passed through the summary network. + + Parameters + ---------- + data : Mapping[str, np.ndarray] + Dictionary of data as NumPy arrays. 
+ **kwargs : dict + Additional keyword arguments for the adapter and the summary network. + + Returns + ------- + summaries : np.ndarray + The output of the summary network for the adapted `summary_variables`. + + Raises + ------ + ValueError + If the approximator does not have a summary network, or the adapter does not produce the output required + by the summary network. + """ + if self.summary_network is None: + raise ValueError("A summary network is required to compute summeries.") + data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs) + if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None: + raise ValueError("Summary variables are required to compute summaries.") + summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"]) + summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call)) + return summaries + def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]: """ Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 1b9d198ff..028e8837a 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -345,3 +345,36 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens output = self.logits_projector(output) return output + + def summaries(self, data: Mapping[str, np.ndarray], **kwargs): + """ + Computes the summaries of given data. + + The `data` dictionary is preprocessed using the `adapter` and passed through the summary network. + + Parameters + ---------- + data : Mapping[str, np.ndarray] + Dictionary of data as NumPy arrays. 
+ **kwargs : dict + Additional keyword arguments for the adapter and the summary network. + + Returns + ------- + summaries : np.ndarray + The output of the summary network for the adapted `summary_variables`. + + Raises + ------ + ValueError + If the approximator does not have a summary network, or the adapter does not produce the output required + by the summary network. + """ + if self.summary_network is None: + raise ValueError("A summary network is required to compute summaries.") + data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs) + if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None: + raise ValueError("Summary variables are required to compute summaries.") + summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"]) + summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call)) + return summaries diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index 1e13e11f2..87823c754 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -1,8 +1,13 @@ -""" +r""" A collection of plotting utilities and metrics for evaluating trained :py:class:`~bayesflow.workflows.Workflow`\ s. 
""" -from .metrics import root_mean_squared_error, calibration_error, posterior_contraction +from .metrics import ( + bootstrap_comparison, + calibration_error, + posterior_contraction, + summary_space_comparison, +) from .plots import ( calibration_ecdf, diff --git a/bayesflow/diagnostics/metrics/__init__.py b/bayesflow/diagnostics/metrics/__init__.py index ceeca4cc4..3e3496cda 100644 --- a/bayesflow/diagnostics/metrics/__init__.py +++ b/bayesflow/diagnostics/metrics/__init__.py @@ -3,3 +3,4 @@ from .root_mean_squared_error import root_mean_squared_error from .expected_calibration_error import expected_calibration_error from .classifier_two_sample_test import classifier_two_sample_test +from .model_misspecification import bootstrap_comparison, summary_space_comparison diff --git a/bayesflow/diagnostics/metrics/model_misspecification.py b/bayesflow/diagnostics/metrics/model_misspecification.py new file mode 100644 index 000000000..c698d4eb2 --- /dev/null +++ b/bayesflow/diagnostics/metrics/model_misspecification.py @@ -0,0 +1,155 @@ +""" +This module provides functions for computing distances between observation samples and reference samples with distance +distributions within the reference samples for hypothesis testing. +""" + +from collections.abc import Mapping, Callable + +import numpy as np +from keras.ops import convert_to_numpy, convert_to_tensor + +from bayesflow.approximators import ContinuousApproximator +from bayesflow.metrics.functional import maximum_mean_discrepancy +from bayesflow.types import Tensor + + +def bootstrap_comparison( + observed_samples: np.ndarray, + reference_samples: np.ndarray, + comparison_fn: Callable[[Tensor, Tensor], Tensor], + num_null_samples: int = 100, +) -> tuple[float, np.ndarray]: + """Computes the distance between observed and reference samples and generates a distribution of null sample + distances by bootstrapping for hypothesis testing. 
+ + Parameters + ---------- + observed_samples : np.ndarray + Observed samples, shape (num_observed, ...). + reference_samples : np.ndarray + Reference samples, shape (num_reference, ...). + comparison_fn : Callable[[Tensor, Tensor], Tensor] + Function to compute the distance metric. + num_null_samples : int + Number of null samples to generate for hypothesis testing. Default is 100. + + Returns + ------- + distance_observed : float + The distance value between observed and reference samples. + distance_null : np.ndarray + A distribution of distance values under the null hypothesis. + + Raises + ------ + ValueError + - If the number of observed samples exceeds the number of reference samples + - If the shapes of observed and reference samples do not match on dimensions besides the first one. + """ + num_observed: int = observed_samples.shape[0] + num_reference: int = reference_samples.shape[0] + + if num_observed > num_reference: + raise ValueError( + f"Number of observed samples ({num_observed}) cannot exceed" + f"the number of reference samples ({num_reference}) for bootstrapping." + ) + if observed_samples.shape[1:] != reference_samples.shape[1:]: + raise ValueError( + f"Expected observed and reference samples to have the same shape, " + f"but got {observed_samples.shape[1:]} != {reference_samples.shape[1:]}." 
+ ) + + observed_samples_tensor: Tensor = convert_to_tensor(observed_samples, dtype="float32") + reference_samples_tensor: Tensor = convert_to_tensor(reference_samples, dtype="float32") + + distance_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64) + for i in range(num_null_samples): + bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed) + bootstrap_samples: np.ndarray = reference_samples[bootstrap_idx] + bootstrap_samples_tensor: Tensor = convert_to_tensor(bootstrap_samples, dtype="float32") + distance_null_samples[i] = convert_to_numpy(comparison_fn(bootstrap_samples_tensor, reference_samples_tensor)) + + distance_observed_tensor: Tensor = comparison_fn( + observed_samples_tensor, + reference_samples_tensor, + ) + + distance_observed: float = float(convert_to_numpy(distance_observed_tensor)) + + return distance_observed, distance_null_samples + + +def summary_space_comparison( + observed_data: Mapping[str, np.ndarray], + reference_data: Mapping[str, np.ndarray], + approximator: ContinuousApproximator, + num_null_samples: int = 100, + comparison_fn: Callable = maximum_mean_discrepancy, + **kwargs, +) -> tuple[float, np.ndarray]: + """Computes the distance between observed and reference data in the summary space and + generates a distribution of distance values under the null hypothesis to assess model misspecification. + + By default, the Maximum Mean Discrepancy (MMD) is used as a distance function. + + [1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian + inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866. + URL: https://arxiv.org/abs/2112.08866 + + Parameters + ---------- + observed_data : dict[str, np.ndarray] + Dictionary of observed data as NumPy arrays, which will be preprocessed by the approximators adapter and passed + through its summary network. 
+ reference_data : dict[str, np.ndarray] + Dictionary of reference data as NumPy arrays, which will be preprocessed by the approximators adapter and passed + through its summary network. + approximator : ContinuousApproximator + An instance of :py:class:`~bayesflow.approximators.ContinuousApproximator` used to compute summary statistics + from the data. + num_null_samples : int, optional + Number of null samples to generate for hypothesis testing. Default is 100. + comparison_fn : Callable, optional + Distance function to compare the data in the summary space. + **kwargs : dict + Additional keyword arguments for the adapter and sampling process. + + Returns + ------- + distance_observed : float + The MMD value between observed and reference summaries. + distance_null : np.ndarray + A distribution of MMD values under the null hypothesis. + + Raises + ------ + ValueError + If approximator is not an instance of ContinuousApproximator or does not have a summary network. + """ + + if not isinstance(approximator, ContinuousApproximator): + raise ValueError("The approximator must be an instance of ContinuousApproximator.") + + if not hasattr(approximator, "summary_network") or approximator.summary_network is None: + comparison_fn_name = ( + "bayesflow.metrics.functional.maximum_mean_discrepancy" + if comparison_fn is maximum_mean_discrepancy + else comparison_fn.__name__ + ) + raise ValueError( + "The approximator must have a summary network. If you have manually crafted summary " + "statistics, or want to compare raw data and not summary statistics, please use the " + f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays." 
+ ) + observed_summaries = convert_to_numpy(approximator.summaries(observed_data)) + reference_summaries = convert_to_numpy(approximator.summaries(reference_data)) + + distance_observed, distance_null = bootstrap_comparison( + observed_samples=observed_summaries, + reference_samples=reference_summaries, + comparison_fn=comparison_fn, + num_null_samples=num_null_samples, + ) + + return distance_observed, distance_null diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 125371a52..227e70ff1 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -163,3 +163,34 @@ def validation_dataset(batch_size, adapter, simulator): num_batches = 2 data = simulator.sample((num_batches * batch_size,)) return OfflineDataset(data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches) + + +@pytest.fixture() +def mean_std_summary_network(): + from tests.utils import MeanStdSummaryNetwork + + return MeanStdSummaryNetwork() + + +@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"]) +def approximator_with_summaries(request): + from bayesflow.adapters import Adapter + + adapter = Adapter() + match request.param: + case "continuous_approximator": + from bayesflow.approximators import ContinuousApproximator + + return ContinuousApproximator(adapter=adapter, inference_network=None, summary_network=None) + case "point_approximator": + from bayesflow.approximators import PointApproximator + + return PointApproximator(adapter=adapter, inference_network=None, summary_network=None) + case "model_comparison_approximator": + from bayesflow.approximators import ModelComparisonApproximator + + return ModelComparisonApproximator( + num_models=2, classifier_network=None, adapter=adapter, summary_network=None + ) + case _: + raise ValueError("Invalid param for approximator class.") diff --git a/tests/test_approximators/test_summaries.py 
b/tests/test_approximators/test_summaries.py new file mode 100644 index 000000000..7962ddaab --- /dev/null +++ b/tests/test_approximators/test_summaries.py @@ -0,0 +1,23 @@ +import pytest +from tests.utils import assert_allclose +import keras + + +def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch): + monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network) + summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))}) + assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1)) + + +def test_no_summary_network(approximator_with_summaries, monkeypatch): + monkeypatch.setattr(approximator_with_summaries, "summary_network", None) + + with pytest.raises(ValueError): + approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))}) + + +def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch): + monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network) + + with pytest.raises(ValueError): + approximator_with_summaries.summaries({}) diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py index 8e77d6729..dc859d2d4 100644 --- a/tests/test_diagnostics/conftest.py +++ b/tests/test_diagnostics/conftest.py @@ -78,3 +78,17 @@ def history(): } return h + + +@pytest.fixture() +def adapter(): + from bayesflow.adapters import Adapter + + return Adapter.create_default("parameters").rename("observables", "summary_variables") + + +@pytest.fixture() +def summary_network(): + from tests.utils import MeanStdSummaryNetwork + + return MeanStdSummaryNetwork() diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py index 4fb0945b3..3a2c711bc 100644 --- a/tests/test_diagnostics/test_diagnostics_metrics.py +++ 
b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -1,6 +1,9 @@ -import bayesflow as bf +import numpy as np +import keras import pytest +import bayesflow as bf + def num_variables(x: dict): return sum(arr.shape[-1] for arr in x.values()) @@ -79,3 +82,288 @@ def test_expected_calibration_error(pred_models, true_models, model_names): with pytest.raises(Exception): out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose) + + +def test_bootstrap_comparison_shapes(): + """Test the bootstrap_comparison output shapes.""" + observed_samples = np.random.rand(10, 5) + reference_samples = np.random.rand(100, 5) + num_null_samples = 50 + + distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + assert isinstance(distance_observed, float) + assert isinstance(distance_null, np.ndarray) + assert distance_null.shape == (num_null_samples,) + + +def test_bootstrap_comparison_same_distribution(): + """Test bootstrap_comparison on same distributions.""" + observed_samples = np.random.normal(loc=0.5, scale=0.1, size=(10, 5)) + reference_samples = observed_samples.copy() + num_null_samples = 5 + + distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + assert distance_observed <= np.quantile(distance_null, 0.99) + + +def test_bootstrap_comparison_different_distributions(): + """Test bootstrap_comparison on different distributions.""" + observed_samples = np.random.normal(loc=-5, scale=0.1, size=(10, 5)) + reference_samples = np.random.normal(loc=5, scale=0.1, size=(100, 5)) + num_null_samples = 50 + + distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda 
x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + assert distance_observed >= np.quantile(distance_null, 0.68) + + +def test_bootstrap_comparison_mismatched_shapes(): + """Test bootstrap_comparison raises ValueError for mismatched shapes.""" + observed_samples = np.random.rand(10, 5) + reference_samples = np.random.rand(20, 4) + num_null_samples = 10 + + with pytest.raises(ValueError): + bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + +def test_bootstrap_comparison_num_observed_exceeds_num_reference(): + """Test bootstrap_comparison raises ValueError when number of observed samples exceeds the number of reference + samples.""" + observed_samples = np.random.rand(100, 5) + reference_samples = np.random.rand(20, 5) + num_null_samples = 50 + + with pytest.raises(ValueError): + bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + +def test_mmd_comparison_from_summaries_shapes(): + """Test the mmd_comparison_from_summaries output shapes.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = np.random.rand(100, 5) + num_null_samples = 50 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert isinstance(mmd_observed, float) + assert isinstance(mmd_null, np.ndarray) + assert mmd_null.shape == (num_null_samples,) + + +def test_mmd_comparison_from_summaries_positive(): + """Test MMD output values of mmd_comparison_from_summaries are positive.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = np.random.rand(100, 5) + num_null_samples = 50 + + mmd_observed, 
mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert mmd_observed >= 0 + assert np.all(mmd_null >= 0) + + +def test_mmd_comparison_from_summaries_same_distribution(): + """Test mmd_comparison_from_summaries on same distributions.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = observed_summaries.copy() + num_null_samples = 5 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert mmd_observed <= np.quantile(mmd_null, 0.99) + + +def test_mmd_comparison_from_summaries_different_distributions(): + """Test mmd_comparison_from_summaries on different distributions.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = np.random.normal(loc=0.5, scale=0.1, size=(100, 5)) + num_null_samples = 50 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert mmd_observed >= np.quantile(mmd_null, 0.68) + + +def test_mmd_comparison_shapes(summary_network, adapter): + """Test the mmd_comparison output shapes.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + 
comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert isinstance(mmd_observed, float) + assert isinstance(mmd_null, np.ndarray) + assert mmd_null.shape == (num_null_samples,) + + +def test_mmd_comparison_positive(summary_network, adapter): + """Test MMD output values of mmd_comparison are positive.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert mmd_observed >= 0 + assert np.all(mmd_null >= 0) + + +def test_mmd_comparison_same_distribution(summary_network, adapter): + """Test mmd_comparison on same distributions.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = observed_data + num_null_samples = 5 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert mmd_observed <= np.quantile(mmd_null, 0.99) + + +def test_mmd_comparison_different_distributions(summary_network, adapter): + """Test mmd_comparison on different distributions.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.normal(loc=0.5, scale=0.1, size=(100, 5))) + num_null_samples = 50 + + 
mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert mmd_observed >= np.quantile(mmd_null, 0.68) + + +def test_mmd_comparison_no_summary_network(adapter): + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=None, + ) + + with pytest.raises(ValueError): + bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + +def test_mmd_comparison_approximator_incorrect_instance(): + """Test mmd_comparison raises ValueError for incorrect approximator instance.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + class IncorrectApproximator: + pass + + mock_approximator = IncorrectApproximator() + + with pytest.raises(ValueError): + bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index f36b02bbd..9c2affc22 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -2,4 +2,5 @@ from .callbacks import * from .check_combinations 
import * from .jupyter import * +from .networks import * from .ops import * diff --git a/tests/utils/networks.py b/tests/utils/networks.py new file mode 100644 index 000000000..cf35e1463 --- /dev/null +++ b/tests/utils/networks.py @@ -0,0 +1,8 @@ +from bayesflow.networks import SummaryNetwork +import keras + + +class MeanStdSummaryNetwork(SummaryNetwork): + def call(self, x): + summary_outputs = keras.ops.stack([keras.ops.mean(x, axis=-1), keras.ops.std(x, axis=-1)], axis=-1) + return summary_outputs From 3a69644a1e95d17268fb11180f90d47fd415479c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20Olischl=C3=A4ger?= <106988117+han-ol@users.noreply.github.com> Date: Sat, 3 May 2025 21:54:52 +0200 Subject: [PATCH 41/46] Add projectors to DeepSet (#453) * v2.0.2 (#447) * [no ci] notebook tests: increase timeout, fix platform/backend dependent code Torch is very slow, so I had to increase the timeout accordingly. * Enable use of summary networks with functional API again (#434) * summary networks: add tests for using functional API * fix build functions for use with functional API * [no ci] docs: add GitHub and Discourse links, reorder navbar * [no ci] docs: acknowledge scikit-learn website * [no ci] docs: capitalize navigation headings * More tests (#437) * fix docs of coupling flow * add additional tests * Automatically run slow tests when main is involved. (#438) In addition, this PR limits the slow test to Windows and Python 3.10. The choices are somewhat arbitrary, my thought was to test the setup not covered as much through use by the devs. 
* Update dispatch * Update dispatching distributions * Improve workflow tests with multiple summary nets / approximators * Fix zombie find_distribution import * Add readme entry [no ci] * Update README: NumFOCUS affiliation, awesome-abi list (#445) * fix is_symbolic_tensor * remove multiple batch sizes, remove multiple python version tests, remove update-workflows branch from workflow style tests, add __init__ and conftest to test_point_approximators (#443) * implement compile_from_config and get_compile_config (#442) * implement compile_from_config and get_compile_config * add optimizer build to compile_from_config * Fix Optimal Transport for Compiled Contexts (#446) * remove the is_symbolic_tensor check because this would otherwise skip the whole function for compiled contexts * skip pyabc test * fix sinkhorn and log_sinkhorn message formatting for jax by making the warning message worse * update dispatch tests for more coverage * Update issue templates (#448) * Hotfix Version 2.0.1 (#431) * fix optimal transport config (#429) * run linter * [skip-ci] bump version to 2.0.1 * Update issue templates * Robustify kwargs passing inference networks, add class variables * fix convergence method to debug for non-log sinkhorn * Bump optimal transport default to False * use logging.info for backend selection instead of logging.debug * fix model comparison approximator * improve docs and type hints * improve One-Sample T-Test Notebook: - use torch as default backend - reduce range of N so users of jax won't be stuck with a slow notebook - use BayesFlow built-in MLP instead of keras.Sequential solution - general code cleanup * remove backend print * [skip ci] turn all single-quoted strings into double-quoted strings * turn all single-quoted strings into double-quoted strings amend to trigger workflow --------- Co-authored-by: Valentin Pratz Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com> Co-authored-by: stefanradev93 Co-authored-by: Marvin Schmitt 
<35921281+marvinschmitt@users.noreply.github.com> * drafting feature * Initialize projectors for invariant and equivariant DeepSet layers * implement requested changes and improve activation --------- Co-authored-by: Lars Co-authored-by: Valentin Pratz Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com> Co-authored-by: stefanradev93 Co-authored-by: Marvin Schmitt <35921281+marvinschmitt@users.noreply.github.com> --- bayesflow/networks/deep_set/deep_set.py | 4 ++-- bayesflow/networks/deep_set/equivariant_layer.py | 6 +++++- bayesflow/networks/deep_set/invariant_layer.py | 4 ++++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index 633a1508b..5fd9cc0b0 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -30,7 +30,7 @@ def __init__( mlp_widths_invariant_inner: Sequence[int] = (64, 64), mlp_widths_invariant_outer: Sequence[int] = (64, 64), mlp_widths_invariant_last: Sequence[int] = (64, 64), - activation: str = "gelu", + activation: str = "silu", kernel_initializer: str = "he_normal", dropout: int | float | None = 0.05, spectral_normalization: bool = False, @@ -72,7 +72,7 @@ def __init__( mlp_widths_invariant_last : Sequence[int], optional Widths of the MLP layers in the final invariant transformation. Default is (64, 64). activation : str, optional - Activation function used throughout the network, such as "gelu". Default is "gelu". + Activation function used throughout the network, such as "gelu". Default is "silu". kernel_initializer : str, optional Initialization strategy for kernel weights, such as "he_normal". Default is "he_normal". 
dropout : int, float, or None, optional diff --git a/bayesflow/networks/deep_set/equivariant_layer.py b/bayesflow/networks/deep_set/equivariant_layer.py index 81bd62f58..0e6587d26 100644 --- a/bayesflow/networks/deep_set/equivariant_layer.py +++ b/bayesflow/networks/deep_set/equivariant_layer.py @@ -94,6 +94,7 @@ def __init__( kernel_initializer=kernel_initializer, spectral_normalization=spectral_normalization, ) + self.out_fc_projector = keras.layers.Dense(mlp_widths_equivariant[-1], kernel_initializer=kernel_initializer) self.layer_norm = layers.LayerNormalization() if layer_norm else None @@ -137,7 +138,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: output_set = ops.concatenate([input_set, invariant_summary], axis=-1) # Pass through final equivariant transform + residual - output_set = input_set + self.equivariant_fc(output_set, training=training) + out_fc = self.equivariant_fc(output_set, training=training) + out_projected = self.out_fc_projector(out_fc) + output_set = input_set + out_projected + if self.layer_norm is not None: output_set = self.layer_norm(output_set, training=training) diff --git a/bayesflow/networks/deep_set/invariant_layer.py b/bayesflow/networks/deep_set/invariant_layer.py index d1b6a26f9..fbf74fb2a 100644 --- a/bayesflow/networks/deep_set/invariant_layer.py +++ b/bayesflow/networks/deep_set/invariant_layer.py @@ -74,6 +74,7 @@ def __init__( kernel_initializer=kernel_initializer, spectral_normalization=spectral_normalization, ) + self.inner_projector = keras.layers.Dense(mlp_widths_inner[-1], kernel_initializer=kernel_initializer) self.outer_fc = MLP( mlp_widths_outer, @@ -82,6 +83,7 @@ def __init__( kernel_initializer=kernel_initializer, spectral_normalization=spectral_normalization, ) + self.outer_projector = keras.layers.Dense(mlp_widths_outer[-1], kernel_initializer=kernel_initializer) # Pooling function as keras layer for sum decomposition: inner( pooling( inner(set) ) ) if pooling_kwargs is None: 
@@ -106,8 +108,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: """ set_summary = self.inner_fc(input_set, training=training) + set_summary = self.inner_projector(set_summary) set_summary = self.pooling_layer(set_summary, training=training) set_summary = self.outer_fc(set_summary, training=training) + set_summary = self.outer_projector(set_summary) return set_summary @sanitize_input_shape From ec938ed40b61ea8d95f76e55e27236893f0b5109 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 4 May 2025 18:50:57 +0000 Subject: [PATCH 42/46] [no ci] bf 1.1 to 2.0 notebook: minor additions/edits --- examples/From_BayesFlow_1.1_to_2.0.ipynb | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/From_BayesFlow_1.1_to_2.0.ipynb b/examples/From_BayesFlow_1.1_to_2.0.ipynb index fa0e13876..b8090fa79 100644 --- a/examples/From_BayesFlow_1.1_to_2.0.ipynb +++ b/examples/From_BayesFlow_1.1_to_2.0.ipynb @@ -15,9 +15,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Long-time users of BayesFlow will notice that with the update to version 2.0 many things have changed. This short guide aims to clarify some of those changes. Users familiar with the previous Quickstart guide will notice that this notebook follows a similar structure, but assumes that users are already familiar with BayesFlow. We omit many of the the mathematical explanations in favor of demonstrating the differences in workflow. For a more detailed explanation of the BayesFlow framework, users should read, for example, the Starter Notebook on Bayesian Linear Regression.\n", + "Long-time users of BayesFlow will notice that with the update to version 2.0 many things have changed. This short guide aims to clarify some of those changes. Users familiar with the previous Quickstart guide will notice that this notebook follows a similar structure, but assumes that users are already familiar with BayesFlow. 
We omit many of the mathematical explanations in favor of demonstrating the differences in workflow. For a more detailed explanation of the BayesFlow framework, users should read, for example, the Starter Notebook on Bayesian Linear Regression.\n", "\n", - "Additionally to avoid confusion, similarly named objects from _BayesFlow v1.1_ will have `1.1` after their name, whereas those from _BayesFlow v2.0_ will not. Finally, a short table with a summary of the function call changes is provided at the end of the guide. " + "Additionally, to avoid confusion, similarly named objects from _BayesFlow v1.1_ will have `1.1` after their name, whereas those from _BayesFlow v2.0_ will not. Finally, a short table with a summary of the function call changes is provided at the end of the guide. " ] }, { @@ -26,7 +26,7 @@ "source": [ "## Keras Framework\n", "\n", - "BayesFlow 2.0 looks quite different from BayesFlow 1.1 because the library was refactored to replace the old backend, TensorFlow, with the new [Keras](https://keras.io) API. Users can now choose their preferred backend among the machine learning frameworks `TensorFlow`, `JAX` and `PyTorch`." + "BayesFlow 2.0 looks quite different from BayesFlow 1.1 because the library was refactored to replace the old backend, TensorFlow, with the new [Keras3](https://keras.io) API. Users can now choose their preferred backend among the machine learning frameworks [TensorFlow](https://github.com/tensorflow/tensorflow), [JAX](https://github.com/google/jax) and [PyTorch](https://github.com/pytorch/pytorch)." ] }, { @@ -52,7 +52,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In general, BayesFlow 2.0 relies much more on dictionaries since parameters are now named by convention. Many objects now expect a dictionary, and parameters and data are returned as dictionaries as well. " + "In general, BayesFlow 2.0 relies much more on dictionaries since parameters and data are now named by convention. 
Many objects now expect a dictionary, and parameters and data are returned as dictionaries as well. " ] }, { @@ -67,7 +67,7 @@ "\n", "Previously users would define a prior function and pass it to a `Prior1.1` object to sample prior values. The likelihood would also be specified via a function and passed to a `Simulator1.1` wrapper to produce observations for given parameter values. These were then combined in the `GenerativeModel1.1`. \n", "\n", - "In 2.0 we no longer make use of the `Prior1.1`, `Simulator1.1` or `GenerativeModel1.1` objects. Instead, the `Simulator` class comprises the whole functionality, taking the role of the `GenerativeModel1.1`. It directly produces joint samples from prior and likelihood, without creating separate `Prior1.1` and `Simulator1.1` objects first. The `bf.simulator.make_simulator` offers a convenient wrapper to create the appropriate simulator for different settings." + "In 2.0 we no longer make use of the `Prior1.1`, `Simulator1.1` or `GenerativeModel1.1` objects. Instead, the `Simulator` class comprises the whole functionality, taking the role of the `GenerativeModel1.1`. It directly produces joint samples from prior and likelihood, without creating separate `Prior1.1` and `Simulator1.1` objects first. The `bf.make_simulator` offers a convenient wrapper to create the appropriate simulator for different settings." ] }, { @@ -217,7 +217,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "For the inference network there are now several implemented architectures for users to choose from. They are `FlowMatching`, `ConsistencyModel`, and `CouplingFlow`. For this demonstration we use `FlowMatching`, for explanations on the different models please refer to the other examples and the API documentation. " + "The previous version only featured the `InvertibleNetwork1.1` class as inference network. As the generative modeling landscape has evolved since then, there are now several implemented architectures for users to choose from. 
`CouplingFlow` features an architecture similar to `InvertibleNetwork1.1`. `FlowMatching` is a continuous flow architecture, which can express more complex distributions, at the cost of higher inference time. For this demonstration we use `FlowMatching`. For explanations on the different models please refer to the other examples and the API documentation. " ] }, { @@ -377,10 +377,11 @@ "| BayesFlow v1.1 | BayesFlow v2.0 usage |\n", "| :--------| :---------| \n", "| `Prior`, `Simulator` | Defunct and no longer standalone objects but incorporated into `bf.simulators.Simulator` | \n", - "|`GenerativeModel` | Defunct with it's functionality having been taken over by `bf.simulators.make_simulator` | \n", + "| `GenerativeModel` | Defunct with its functionality having been taken over by `bf.make_simulator` | \n", "| `training.configurator` | Functionality taken over by `bf.adapters.Adapter` | \n", "|`Trainer` | Functionality taken over by `fit` method of `bf.approximators.Approximator` | \n", - "| `AmortizedPosterior`, `AmortizedLikelihood` | Functionality taken over by `ContinuousApproximator` | " + "| `AmortizedPosterior`, `AmortizedLikelihood` | Functionality taken over by `ContinuousApproximator` |\n", + "| `InvertibleNetwork` | Functionality taken over by `CouplingFlow`, but also other networks that are subclasses of `InferenceNetwork`, e.g. `FlowMatching` |" ] } ], From ddd4ea8e3cd53ec38f0cb65b77a16489bdeeaca6 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Mon, 5 May 2025 07:50:51 +0000 Subject: [PATCH 43/46] [no ci] add minor details to FAQ --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5d2dc61f0..7c8d08d2f 100644 --- a/README.md +++ b/README.md @@ -245,8 +245,9 @@ Depending on your needs, you might not want to upgrade yet if one of the followi with the new version. Loading models from version 1.x in version 2.0+ is not supported. 
- You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge, this applies to: - * Two-level/Hierarchical models: `TwoLevelGenerativeModel`, `TwoLevelPrior`. - * Sensitivity analysis: functionality from the `bayesflow.sensitivity` module. + * Two-level/Hierarchical models (planned for version 2.1): `TwoLevelGenerativeModel`, `TwoLevelPrior`. + * Sensitivity analysis (partially discontinued): functionality from the `bayesflow.sensitivity` module. This is still + possible, but we no longer offer a special module for it. We plan to add a tutorial on this, see [#455](https://github.com/bayesflow-org/bayesflow/issues/455). * MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options to enable the use of BayesFlow in an MCMC setting. * Networks: `EvidentialNetwork`. From e8d2f2cdd0e3ed34248044ffb6f44888da1744a3 Mon Sep 17 00:00:00 2001 From: Valentin Pratz <112951103+vpratz@users.noreply.github.com> Date: Mon, 5 May 2025 15:24:10 +0200 Subject: [PATCH 44/46] serialization: apply new scheme for package (breaking change) (#457) * serialization: apply new scheme for `package` (breaking) - introduces new policy for consistent naming for serialization (see #451 for a discussion): standard is the path of a module a class resides in, truncated at depth two. So for all classes in bayesflow.networks, we set package="bayesflow.networks", even if they live in the bayesflow.networks.mlp submodule. - The `serializable` decorator checks this and errors if this is not followed. The check can be disabled for certain cases (e.g., classes in the experimental module, that might eventually live somewhere else). - After this commit, previously saved models will not be loadable. As we introduced a bug regarding this anyway (#451), we will accept this and should inform users about it. - usage of direct calls to `keras.saving.register_keras_serializable` was replaced with our custom decorator. 
* update serialization policy in dev docs * README: add note regarding breaking changes until 2.1 release * standardize use of serializable decorator * [no ci] change (de)serialize to new pipeline in transform * serialization check: exempt classes not in bayesflow module This should ensure that users that try to use our decorator with external classes do not encounter the error. Possible edge case: they also name their module "bayesflow". --------- Co-authored-by: LarsKue --- README.md | 6 +++ bayesflow/adapters/adapter.py | 2 +- bayesflow/adapters/transforms/as_set.py | 2 +- .../adapters/transforms/as_time_series.py | 2 +- bayesflow/adapters/transforms/broadcast.py | 2 +- bayesflow/adapters/transforms/concatenate.py | 2 +- bayesflow/adapters/transforms/constrain.py | 2 +- .../adapters/transforms/convert_dtype.py | 2 +- bayesflow/adapters/transforms/drop.py | 2 +- .../transforms/elementwise_transform.py | 2 +- bayesflow/adapters/transforms/expand_dims.py | 2 +- .../adapters/transforms/filter_transform.py | 2 +- bayesflow/adapters/transforms/keep.py | 2 +- bayesflow/adapters/transforms/log.py | 2 +- .../adapters/transforms/map_transform.py | 2 +- .../adapters/transforms/numpy_transform.py | 2 +- bayesflow/adapters/transforms/one_hot.py | 2 +- bayesflow/adapters/transforms/rename.py | 2 +- bayesflow/adapters/transforms/scale.py | 2 +- .../serializable_custom_transform.py | 7 ++-- bayesflow/adapters/transforms/shift.py | 2 +- bayesflow/adapters/transforms/split.py | 2 +- bayesflow/adapters/transforms/sqrt.py | 2 +- bayesflow/adapters/transforms/standardize.py | 2 +- bayesflow/adapters/transforms/to_array.py | 2 +- bayesflow/adapters/transforms/to_dict.py | 2 +- bayesflow/adapters/transforms/transform.py | 2 +- .../approximators/continuous_approximator.py | 2 +- .../model_comparison_approximator.py | 2 +- bayesflow/approximators/point_approximator.py | 2 +- bayesflow/distributions/diagonal_normal.py | 2 +- bayesflow/distributions/diagonal_student_t.py | 2 +- 
bayesflow/distributions/distribution.py | 2 +- bayesflow/distributions/mixture.py | 2 +- bayesflow/experimental/cif/cif.py | 5 ++- .../experimental/cif/conditional_gaussian.py | 5 ++- .../continuous_time_consistency_model.py | 3 +- .../free_form_flow/free_form_flow.py | 3 +- bayesflow/experimental/resnet/dense_resnet.py | 3 +- bayesflow/experimental/resnet/double_conv.py | 3 +- .../experimental/resnet/double_linear.py | 3 +- bayesflow/experimental/resnet/resnet.py | 3 +- bayesflow/links/ordered.py | 4 +- bayesflow/links/ordered_quantiles.py | 4 +- bayesflow/links/positive_definite.py | 5 +-- bayesflow/metrics/maximum_mean_discrepancy.py | 2 +- bayesflow/metrics/root_mean_squard_error.py | 2 +- .../consistency_models/consistency_model.py | 2 +- bayesflow/networks/coupling_flow/actnorm.py | 2 +- .../networks/coupling_flow/coupling_flow.py | 2 +- .../coupling_flow/couplings/dual_coupling.py | 2 +- .../couplings/single_coupling.py | 2 +- .../permutations/fixed_permutation.py | 2 +- .../coupling_flow/permutations/orthogonal.py | 2 +- .../coupling_flow/permutations/random.py | 2 +- .../coupling_flow/permutations/swap.py | 2 +- .../transforms/affine_transform.py | 2 +- .../transforms/spline_transform.py | 2 +- bayesflow/networks/deep_set/deep_set.py | 2 +- .../networks/deep_set/equivariant_layer.py | 2 +- .../networks/deep_set/invariant_layer.py | 2 +- .../networks/embeddings/fourier_embedding.py | 2 +- .../embeddings/recurrent_embedding.py | 2 +- bayesflow/networks/embeddings/time2vec.py | 2 +- .../networks/flow_matching/flow_matching.py | 2 +- bayesflow/networks/mlp/mlp.py | 2 +- bayesflow/networks/point_inference_network.py | 2 +- bayesflow/networks/residual/residual.py | 2 +- .../time_series_network/skip_recurrent.py | 2 +- .../time_series_network.py | 2 +- .../transformers/fusion_transformer.py | 2 +- bayesflow/networks/transformers/isab.py | 2 +- bayesflow/networks/transformers/mab.py | 2 +- bayesflow/networks/transformers/pma.py | 2 +- 
bayesflow/networks/transformers/sab.py | 2 +- .../networks/transformers/set_transformer.py | 2 +- .../transformers/time_series_transformer.py | 2 +- bayesflow/scores/mean_score.py | 5 +-- bayesflow/scores/median_score.py | 5 +-- bayesflow/scores/multivariate_normal_score.py | 4 +- bayesflow/scores/normed_difference_score.py | 4 +- .../scores/parametric_distribution_score.py | 5 +-- bayesflow/scores/quantile_score.py | 4 +- bayesflow/scores/scoring_rule.py | 2 +- bayesflow/utils/serialization.py | 39 ++++++++++++++----- bayesflow/wrappers/mamba/mamba.py | 2 +- bayesflow/wrappers/mamba/mamba_block.py | 2 +- docsrc/source/development/serialization.md | 11 +++++- .../test_utils/test_serialize_deserialize.py | 4 +- tests/test_workflows/conftest.py | 2 +- 90 files changed, 155 insertions(+), 116 deletions(-) diff --git a/README.md b/README.md index 7c8d08d2f..9b5ef3d47 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,12 @@ It provides users and researchers with: BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference fueled by continuous progress in generative AI and Bayesian inference. +> [!IMPORTANT] +> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time. +> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards. +> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update +> for less costly models. + ## Important Note for Existing Users You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library. 
diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index ab6800d8a..a17a59d81 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -29,7 +29,7 @@ from .transforms.filter_transform import Predicate -@serializable +@serializable("bayesflow.adapters") class Adapter(MutableSequence[Transform]): """ Defines an adapter to apply various transforms to data. diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py index f4d5bdfc5..903536bc4 100644 --- a/bayesflow/adapters/transforms/as_set.py +++ b/bayesflow/adapters/transforms/as_set.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class AsSet(ElementwiseTransform): """The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets. diff --git a/bayesflow/adapters/transforms/as_time_series.py b/bayesflow/adapters/transforms/as_time_series.py index 3f4d2a2c5..d7791352c 100644 --- a/bayesflow/adapters/transforms/as_time_series.py +++ b/bayesflow/adapters/transforms/as_time_series.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class AsTimeSeries(ElementwiseTransform): """The `.as_time_series` transform can be used to indicate that variables shall be treated as time series. diff --git a/bayesflow/adapters/transforms/broadcast.py b/bayesflow/adapters/transforms/broadcast.py index 8667ec0c7..646e1f72e 100644 --- a/bayesflow/adapters/transforms/broadcast.py +++ b/bayesflow/adapters/transforms/broadcast.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Broadcast(Transform): """ Broadcasts arrays or scalars to the shape of a given other array. 
diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 91ea9178b..ac3700616 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -7,7 +7,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Concatenate(Transform): """Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network. diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py index a4ca0be25..d01211dfc 100644 --- a/bayesflow/adapters/transforms/constrain.py +++ b/bayesflow/adapters/transforms/constrain.py @@ -11,7 +11,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Constrain(ElementwiseTransform): """ Constrains neural network predictions of a data variable to specified bounds. diff --git a/bayesflow/adapters/transforms/convert_dtype.py b/bayesflow/adapters/transforms/convert_dtype.py index e68815269..8cd21b4cc 100644 --- a/bayesflow/adapters/transforms/convert_dtype.py +++ b/bayesflow/adapters/transforms/convert_dtype.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class ConvertDType(ElementwiseTransform): """ Default transform used to convert all floats from float64 to float32 to be in line with keras framework. diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py index 91dcd6a28..5073027e6 100644 --- a/bayesflow/adapters/transforms/drop.py +++ b/bayesflow/adapters/transforms/drop.py @@ -5,7 +5,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Drop(Transform): """ Transform to drop variables from further calculation. 
diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py index 7d603d517..020301749 100644 --- a/bayesflow/adapters/transforms/elementwise_transform.py +++ b/bayesflow/adapters/transforms/elementwise_transform.py @@ -3,7 +3,7 @@ from bayesflow.utils.serialization import serializable, deserialize -@serializable +@serializable("bayesflow.adapters") class ElementwiseTransform: """Base class on which other transforms are based""" diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py index e44d133b8..0f4151d37 100644 --- a/bayesflow/adapters/transforms/expand_dims.py +++ b/bayesflow/adapters/transforms/expand_dims.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class ExpandDims(ElementwiseTransform): """ Expand the shape of an array. diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index 7eccf370b..4dc2c8008 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -14,7 +14,7 @@ def __call__(self, key: str, value: np.ndarray, inverse: bool) -> bool: raise NotImplementedError -@serializable +@serializable("bayesflow.adapters") class FilterTransform(Transform): """ Implements a transform that applies a different transform on a subset of the data. diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py index 56f395166..c69d01ca3 100644 --- a/bayesflow/adapters/transforms/keep.py +++ b/bayesflow/adapters/transforms/keep.py @@ -5,7 +5,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Keep(Transform): """ Name the data parameters that should be kept for futher calculation. 
diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index d5f559b4f..a42c43ef0 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Log(ElementwiseTransform): """Log transforms a variable. diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py index 5da8292af..15c5c945d 100644 --- a/bayesflow/adapters/transforms/map_transform.py +++ b/bayesflow/adapters/transforms/map_transform.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class MapTransform(Transform): """ Implements a transform that applies a set of elementwise transforms diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index 29d25dc67..a19216dd2 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class NumpyTransform(ElementwiseTransform): """ A class to apply element-wise transformations using plain NumPy functions. diff --git a/bayesflow/adapters/transforms/one_hot.py b/bayesflow/adapters/transforms/one_hot.py index e097a28f9..bbed8bf5d 100644 --- a/bayesflow/adapters/transforms/one_hot.py +++ b/bayesflow/adapters/transforms/one_hot.py @@ -6,7 +6,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class OneHot(ElementwiseTransform): """ Changes data to be one-hot encoded. 
diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 746ef5a80..bec3388b0 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -3,7 +3,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Rename(Transform): """ Transform to rename keys in data dictionary. Useful to rename variables to match those required by diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 96b2ff927..d7c1aa2a7 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Scale(ElementwiseTransform): def __init__(self, scale: np.typing.ArrayLike): self.scale = np.array(scale) diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py index 75d588afd..248fc0ec5 100644 --- a/bayesflow/adapters/transforms/serializable_custom_transform.py +++ b/bayesflow/adapters/transforms/serializable_custom_transform.py @@ -1,18 +1,17 @@ from collections.abc import Callable import numpy as np from keras.saving import ( - deserialize_keras_object as deserialize, - register_keras_serializable as serializable, - serialize_keras_object as serialize, get_registered_name, get_registered_object, ) + +from bayesflow.utils.serialization import deserialize, serializable, serialize from .elementwise_transform import ElementwiseTransform from ...utils import filter_kwargs import inspect -@serializable(package="bayesflow.adapters") +@serializable("bayesflow.adapters") class SerializableCustomTransform(ElementwiseTransform): """ Transforms a parameter using a pair of registered serializable forward and inverse functions. 
diff --git a/bayesflow/adapters/transforms/shift.py b/bayesflow/adapters/transforms/shift.py index b7c9659d2..5923b4e49 100644 --- a/bayesflow/adapters/transforms/shift.py +++ b/bayesflow/adapters/transforms/shift.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Shift(ElementwiseTransform): def __init__(self, shift: np.typing.ArrayLike): self.shift = np.array(shift) diff --git a/bayesflow/adapters/transforms/split.py b/bayesflow/adapters/transforms/split.py index 919db4e08..4c0ae9f65 100644 --- a/bayesflow/adapters/transforms/split.py +++ b/bayesflow/adapters/transforms/split.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class Split(Transform): """This is the effective inverse of the :py:class:`~Concatenate` Transform. diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 4ef1370dc..bcfe49136 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -5,7 +5,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Sqrt(ElementwiseTransform): """Square-root transform a variable. 
diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 9699819b9..a1c3c5a3d 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -7,7 +7,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class Standardize(ElementwiseTransform): """ Transform that when applied standardizes data using typical z-score standardization diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index 9d5381ca0..fe1b82f2d 100644 --- a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -7,7 +7,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable +@serializable("bayesflow.adapters") class ToArray(ElementwiseTransform): """ Checks provided data for any non-arrays and converts them to numpy arrays. diff --git a/bayesflow/adapters/transforms/to_dict.py b/bayesflow/adapters/transforms/to_dict.py index 6babb2a40..cfc4ec00d 100644 --- a/bayesflow/adapters/transforms/to_dict.py +++ b/bayesflow/adapters/transforms/to_dict.py @@ -6,7 +6,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.adapters") class ToDict(Transform): """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches""" diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py index ed3058e15..0bc6331bc 100644 --- a/bayesflow/adapters/transforms/transform.py +++ b/bayesflow/adapters/transforms/transform.py @@ -3,7 +3,7 @@ from bayesflow.utils.serialization import serializable, deserialize -@serializable +@serializable("bayesflow.adapters") class Transform: """ Base class on which other transforms are based diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 834521d4b..3e43a8917 100644 --- 
a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -13,7 +13,7 @@ from .approximator import Approximator -@serializable +@serializable("bayesflow.approximators") class ContinuousApproximator(Approximator): """ Defines a workflow for performing fast posterior or likelihood inference. diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 028e8837a..86af4ee33 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -14,7 +14,7 @@ from .approximator import Approximator -@serializable +@serializable("bayesflow.approximators") class ModelComparisonApproximator(Approximator): """ Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are diff --git a/bayesflow/approximators/point_approximator.py b/bayesflow/approximators/point_approximator.py index 1e407e2a6..b3d90781c 100644 --- a/bayesflow/approximators/point_approximator.py +++ b/bayesflow/approximators/point_approximator.py @@ -11,7 +11,7 @@ from .continuous_approximator import ContinuousApproximator -@serializable +@serializable("bayesflow.approximators") class PointApproximator(ContinuousApproximator): """ A workflow for fast amortized point estimation of a conditional distribution. 
diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 98a127b1c..f8d93b945 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -12,7 +12,7 @@ from .distribution import Distribution -@serializable +@serializable("bayesflow.distributions") class DiagonalNormal(Distribution): """Implements a backend-agnostic diagonal Gaussian distribution.""" diff --git a/bayesflow/distributions/diagonal_student_t.py b/bayesflow/distributions/diagonal_student_t.py index cd32a67fb..98e3fb7eb 100644 --- a/bayesflow/distributions/diagonal_student_t.py +++ b/bayesflow/distributions/diagonal_student_t.py @@ -13,7 +13,7 @@ from .distribution import Distribution -@serializable +@serializable("bayesflow.distributions") class DiagonalStudentT(Distribution): """Implements a backend-agnostic diagonal Student-t distribution.""" diff --git a/bayesflow/distributions/distribution.py b/bayesflow/distributions/distribution.py index 1d3a83962..3689f0d9f 100644 --- a/bayesflow/distributions/distribution.py +++ b/bayesflow/distributions/distribution.py @@ -5,7 +5,7 @@ from bayesflow.utils.serialization import serializable, deserialize -@serializable +@serializable("bayesflow.distributions") class Distribution(keras.Layer): def __init__(self, **kwargs): super().__init__(**layer_kwargs(kwargs)) diff --git a/bayesflow/distributions/mixture.py b/bayesflow/distributions/mixture.py index d7f6bd758..a7bf2ea27 100644 --- a/bayesflow/distributions/mixture.py +++ b/bayesflow/distributions/mixture.py @@ -11,7 +11,7 @@ from bayesflow.distributions import Distribution -@serializable +@serializable("bayesflow.distributions") class Mixture(Distribution): """Utility class for a backend-agnostic mixture distributions.""" diff --git a/bayesflow/experimental/cif/cif.py b/bayesflow/experimental/cif/cif.py index 8742501b3..bd776f93e 100644 --- a/bayesflow/experimental/cif/cif.py +++ b/bayesflow/experimental/cif/cif.py 
@@ -1,7 +1,7 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor +from bayesflow.utils.serialization import serializable from bayesflow.networks.inference_network import InferenceNetwork from bayesflow.networks.coupling_flow import CouplingFlow @@ -9,7 +9,8 @@ from .conditional_gaussian import ConditionalGaussian -@serializable(package="bayesflow.networks") +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class CIF(InferenceNetwork): """Implements a continuously indexed flow (CIF) with a `CouplingFlow` bijection and `ConditionalGaussian` distributions p and q. Improves on diff --git a/bayesflow/experimental/cif/conditional_gaussian.py b/bayesflow/experimental/cif/conditional_gaussian.py index d11f3fb65..ebba47a2e 100644 --- a/bayesflow/experimental/cif/conditional_gaussian.py +++ b/bayesflow/experimental/cif/conditional_gaussian.py @@ -1,13 +1,14 @@ import keras -from keras.saving import register_keras_serializable import numpy as np from bayesflow.networks.mlp import MLP from bayesflow.types import Shape, Tensor from bayesflow.utils import layer_kwargs +from bayesflow.utils.serialization import serializable -@register_keras_serializable(package="bayesflow.networks.cif") +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class ConditionalGaussian(keras.Layer): """Implements a conditional gaussian distribution with neural networks for the means and standard deviations respectively. Bulit in reference to [1]. 
diff --git a/bayesflow/experimental/continuous_time_consistency_model.py b/bayesflow/experimental/continuous_time_consistency_model.py index 54417cd07..b1c751454 100644 --- a/bayesflow/experimental/continuous_time_consistency_model.py +++ b/bayesflow/experimental/continuous_time_consistency_model.py @@ -22,7 +22,8 @@ from bayesflow.networks.embeddings import FourierEmbedding -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class ContinuousTimeConsistencyModel(InferenceNetwork): """Implements an sCM (simple, stable, and scalable Consistency Model) with continous-time Consistency Training (CT) as described in [1]. diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py index 61937d56f..12bb97b93 100644 --- a/bayesflow/experimental/free_form_flow/free_form_flow.py +++ b/bayesflow/experimental/free_form_flow/free_form_flow.py @@ -19,7 +19,8 @@ from bayesflow.networks import InferenceNetwork -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class FreeFormFlow(InferenceNetwork): """Implements a dimensionality-preserving Free-form Flow. Incorporates ideas from [1-2]. diff --git a/bayesflow/experimental/resnet/dense_resnet.py b/bayesflow/experimental/resnet/dense_resnet.py index fa380969f..93ff59d3f 100644 --- a/bayesflow/experimental/resnet/dense_resnet.py +++ b/bayesflow/experimental/resnet/dense_resnet.py @@ -8,7 +8,8 @@ from .double_linear import DoubleLinear -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class DenseResNet(keras.Sequential): """ Implements the fully-connected analogue of the ResNet architecture. 
diff --git a/bayesflow/experimental/resnet/double_conv.py b/bayesflow/experimental/resnet/double_conv.py index c70e37323..a2b6bbc88 100644 --- a/bayesflow/experimental/resnet/double_conv.py +++ b/bayesflow/experimental/resnet/double_conv.py @@ -5,7 +5,8 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class DoubleConv(keras.Sequential): def __init__( self, diff --git a/bayesflow/experimental/resnet/double_linear.py b/bayesflow/experimental/resnet/double_linear.py index e2138c8b0..aae72fa39 100644 --- a/bayesflow/experimental/resnet/double_linear.py +++ b/bayesflow/experimental/resnet/double_linear.py @@ -5,7 +5,8 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class DoubleLinear(keras.Sequential): def __init__( self, diff --git a/bayesflow/experimental/resnet/resnet.py b/bayesflow/experimental/resnet/resnet.py index 07f1f2cda..862e0ac98 100644 --- a/bayesflow/experimental/resnet/resnet.py +++ b/bayesflow/experimental/resnet/resnet.py @@ -8,7 +8,8 @@ from .double_conv import DoubleConv -@serializable +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) class ResNet(keras.Sequential): """ Implements the ResNet architecture. 
diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py index 77545b6f8..25caf5350 100644 --- a/bayesflow/links/ordered.py +++ b/bayesflow/links/ordered.py @@ -1,11 +1,11 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.utils import layer_kwargs from bayesflow.utils.decorators import sanitize_input_shape +from bayesflow.utils.serialization import serializable -@serializable(package="links.ordered") +@serializable("bayesflow.links") class Ordered(keras.Layer): """Activation function to link to a tensor which is monotonously increasing along a specified axis.""" diff --git a/bayesflow/links/ordered_quantiles.py b/bayesflow/links/ordered_quantiles.py index d4f4caba2..81b2c0cc7 100644 --- a/bayesflow/links/ordered_quantiles.py +++ b/bayesflow/links/ordered_quantiles.py @@ -1,14 +1,14 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.utils import layer_kwargs, logging +from bayesflow.utils.serialization import serializable from collections.abc import Sequence from .ordered import Ordered -@serializable(package="links.ordered_quantiles") +@serializable("bayesflow.links") class OrderedQuantiles(Ordered): """Activation function to link to monotonously increasing quantile estimates.""" diff --git a/bayesflow/links/positive_definite.py b/bayesflow/links/positive_definite.py index 28c937f86..909ac2792 100644 --- a/bayesflow/links/positive_definite.py +++ b/bayesflow/links/positive_definite.py @@ -1,12 +1,11 @@ import keras -from keras.saving import register_keras_serializable as serializable - from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, fill_triangular_matrix +from bayesflow.utils.serialization import serializable -@serializable(package="bayesflow.links") +@serializable("bayesflow.links") class PositiveDefinite(keras.Layer): """Activation function to link from flat elements of a lower triangular matrix to a positive definite 
matrix.""" diff --git a/bayesflow/metrics/maximum_mean_discrepancy.py b/bayesflow/metrics/maximum_mean_discrepancy.py index 37af44fd4..de4ee32f1 100644 --- a/bayesflow/metrics/maximum_mean_discrepancy.py +++ b/bayesflow/metrics/maximum_mean_discrepancy.py @@ -6,7 +6,7 @@ from .functional import maximum_mean_discrepancy -@serializable +@serializable("bayesflow.metrics") class MaximumMeanDiscrepancy(keras.Metric): def __init__( self, diff --git a/bayesflow/metrics/root_mean_squard_error.py b/bayesflow/metrics/root_mean_squard_error.py index 97de62e6a..8827095e9 100644 --- a/bayesflow/metrics/root_mean_squard_error.py +++ b/bayesflow/metrics/root_mean_squard_error.py @@ -5,7 +5,7 @@ from .functional import root_mean_squared_error -@serializable +@serializable("bayesflow.metrics") class RootMeanSquaredError(keras.metrics.MeanMetricWrapper): def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs): fn = partial(root_mean_squared_error, **kwargs) diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index b8d4c56ed..8d36c1736 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -12,7 +12,7 @@ from ..inference_network import InferenceNetwork -@serializable +@serializable("bayesflow.networks") class ConsistencyModel(InferenceNetwork): """Implements a Consistency Model with Consistency Training (CT) a described in [1-2]. The adaptations to CT described in [2] were taken into account in our implementation for ABI [3]. 
diff --git a/bayesflow/networks/coupling_flow/actnorm.py b/bayesflow/networks/coupling_flow/actnorm.py index 5221caea1..81cdc425d 100644 --- a/bayesflow/networks/coupling_flow/actnorm.py +++ b/bayesflow/networks/coupling_flow/actnorm.py @@ -6,7 +6,7 @@ from .invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class ActNorm(InvertibleLayer): """Implements an Activation Normalization (ActNorm) Layer. Activation Normalization is learned invertible normalization, using a scale (s) and a bias (b) vector:: diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index 203962b0f..28954d7d2 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -13,7 +13,7 @@ from ..inference_network import InferenceNetwork -@serializable +@serializable("bayesflow.networks") class CouplingFlow(InferenceNetwork): """Implements a coupling flow as a sequence of dual couplings with permutations and activation normalization. Incorporates ideas from [1-5]. 
diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py index 67db6e269..462bc02d6 100644 --- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py @@ -9,7 +9,7 @@ from ..invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class DualCoupling(InvertibleLayer): def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwargs): super().__init__(**kwargs) diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py index d2703cfa7..7bd6aaf3b 100644 --- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py @@ -8,7 +8,7 @@ from ..transforms import find_transform -@serializable +@serializable("bayesflow.networks") class SingleCoupling(InvertibleLayer): """ Implements a single coupling layer as a composition of a subnet and a transform. diff --git a/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py b/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py index d93f27ae5..68591a172 100644 --- a/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py +++ b/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py @@ -6,7 +6,7 @@ from ..invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class FixedPermutation(InvertibleLayer): """ Interface class for permutations with no learnable parameters. 
Child classes should diff --git a/bayesflow/networks/coupling_flow/permutations/orthogonal.py b/bayesflow/networks/coupling_flow/permutations/orthogonal.py index a28fe7965..54bfbf901 100644 --- a/bayesflow/networks/coupling_flow/permutations/orthogonal.py +++ b/bayesflow/networks/coupling_flow/permutations/orthogonal.py @@ -6,7 +6,7 @@ from ..invertible_layer import InvertibleLayer -@serializable +@serializable("bayesflow.networks") class OrthogonalPermutation(InvertibleLayer): """Implements a learnable orthogonal transformation according to [1]. Can be used as an alternative to a fixed ``Permutation`` layer. diff --git a/bayesflow/networks/coupling_flow/permutations/random.py b/bayesflow/networks/coupling_flow/permutations/random.py index 82d7f39ff..522d48c63 100644 --- a/bayesflow/networks/coupling_flow/permutations/random.py +++ b/bayesflow/networks/coupling_flow/permutations/random.py @@ -6,7 +6,7 @@ from .fixed_permutation import FixedPermutation -@serializable +@serializable("bayesflow.networks") class RandomPermutation(FixedPermutation): # noinspection PyMethodOverriding def build(self, xz_shape: Shape, **kwargs) -> None: diff --git a/bayesflow/networks/coupling_flow/permutations/swap.py b/bayesflow/networks/coupling_flow/permutations/swap.py index c5f707a1a..bb7f641b9 100644 --- a/bayesflow/networks/coupling_flow/permutations/swap.py +++ b/bayesflow/networks/coupling_flow/permutations/swap.py @@ -6,7 +6,7 @@ from .fixed_permutation import FixedPermutation -@serializable +@serializable("bayesflow.networks") class Swap(FixedPermutation): def build(self, xz_shape: Shape, **kwargs) -> None: shift = xz_shape[-1] // 2 diff --git a/bayesflow/networks/coupling_flow/transforms/affine_transform.py b/bayesflow/networks/coupling_flow/transforms/affine_transform.py index 9e8c4a9e1..1d66b0bfb 100644 --- a/bayesflow/networks/coupling_flow/transforms/affine_transform.py +++ b/bayesflow/networks/coupling_flow/transforms/affine_transform.py @@ -7,7 +7,7 @@ from .transform 
import Transform -@serializable +@serializable("bayesflow.networks") class AffineTransform(Transform): def __init__(self, clamp: bool = True, **kwargs): super().__init__(**kwargs) diff --git a/bayesflow/networks/coupling_flow/transforms/spline_transform.py b/bayesflow/networks/coupling_flow/transforms/spline_transform.py index d5b0cf4b3..28b9c4415 100644 --- a/bayesflow/networks/coupling_flow/transforms/spline_transform.py +++ b/bayesflow/networks/coupling_flow/transforms/spline_transform.py @@ -10,7 +10,7 @@ from .transform import Transform -@serializable +@serializable("bayesflow.networks") class SplineTransform(Transform): def __init__( self, diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index 5fd9cc0b0..9c9d0ad23 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -11,7 +11,7 @@ from ..summary_network import SummaryNetwork -@serializable +@serializable("bayesflow.networks") class DeepSet(SummaryNetwork): """Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of set-based data, as generated by exchangeable models. diff --git a/bayesflow/networks/deep_set/equivariant_layer.py b/bayesflow/networks/deep_set/equivariant_layer.py index 0e6587d26..7e35ad9bb 100644 --- a/bayesflow/networks/deep_set/equivariant_layer.py +++ b/bayesflow/networks/deep_set/equivariant_layer.py @@ -13,7 +13,7 @@ from .invariant_layer import InvariantLayer -@serializable +@serializable("bayesflow.networks") class EquivariantLayer(keras.Layer): """Implements an equivariant module performing an equivariant transform. 
diff --git a/bayesflow/networks/deep_set/invariant_layer.py b/bayesflow/networks/deep_set/invariant_layer.py index fbf74fb2a..2f29c6b8d 100644 --- a/bayesflow/networks/deep_set/invariant_layer.py +++ b/bayesflow/networks/deep_set/invariant_layer.py @@ -11,7 +11,7 @@ from ..mlp import MLP -@serializable +@serializable("bayesflow.networks") class InvariantLayer(keras.Layer): """Implements an invariant module performing a permutation-invariant transform. diff --git a/bayesflow/networks/embeddings/fourier_embedding.py b/bayesflow/networks/embeddings/fourier_embedding.py index 65b5938d7..21924ee60 100644 --- a/bayesflow/networks/embeddings/fourier_embedding.py +++ b/bayesflow/networks/embeddings/fourier_embedding.py @@ -7,7 +7,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class FourierEmbedding(keras.Layer): """Implements a Fourier projection with normally distributed frequencies.""" diff --git a/bayesflow/networks/embeddings/recurrent_embedding.py b/bayesflow/networks/embeddings/recurrent_embedding.py index df7c00f32..3fa82868d 100644 --- a/bayesflow/networks/embeddings/recurrent_embedding.py +++ b/bayesflow/networks/embeddings/recurrent_embedding.py @@ -6,7 +6,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class RecurrentEmbedding(keras.Layer): """Implements a recurrent network for flexibly embedding time vectors.""" diff --git a/bayesflow/networks/embeddings/time2vec.py b/bayesflow/networks/embeddings/time2vec.py index 4c9c3a87f..b52ca77d8 100644 --- a/bayesflow/networks/embeddings/time2vec.py +++ b/bayesflow/networks/embeddings/time2vec.py @@ -5,7 +5,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class Time2Vec(keras.Layer): """ Implements the Time2Vec learnbale embedding from [1]. 
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 3c0190467..797d4c62d 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -19,7 +19,7 @@ from ..inference_network import InferenceNetwork -@serializable +@serializable("bayesflow.networks") class FlowMatching(InferenceNetwork): """Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated from [1-3]. diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py index 1ac11fe1a..11dcdca2b 100644 --- a/bayesflow/networks/mlp/mlp.py +++ b/bayesflow/networks/mlp/mlp.py @@ -9,7 +9,7 @@ from ..residual import Residual -@serializable +@serializable("bayesflow.networks") class MLP(keras.Sequential): """ Implements a simple configurable MLP with optional residual connections and dropout. diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py index 63094a2a8..402632355 100644 --- a/bayesflow/networks/point_inference_network.py +++ b/bayesflow/networks/point_inference_network.py @@ -7,7 +7,7 @@ from bayesflow.utils.decorators import allow_batch_size -@serializable(package="networks.point_inference_network") +@serializable("bayesflow.networks") class PointInferenceNetwork(keras.Layer): """Implements point estimation for user specified scoring rules by a shared feed forward architecture with separate heads for each scoring rule. 
diff --git a/bayesflow/networks/residual/residual.py b/bayesflow/networks/residual/residual.py index f2ca54b51..edf32782c 100644 --- a/bayesflow/networks/residual/residual.py +++ b/bayesflow/networks/residual/residual.py @@ -7,7 +7,7 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +@serializable("bayesflow.networks") class Residual(keras.Sequential): def __init__(self, *layers: keras.Layer, **kwargs): if len(layers) == 1 and isinstance(layers[0], Sequence): diff --git a/bayesflow/networks/time_series_network/skip_recurrent.py b/bayesflow/networks/time_series_network/skip_recurrent.py index 9b2c06c0d..23dee5156 100644 --- a/bayesflow/networks/time_series_network/skip_recurrent.py +++ b/bayesflow/networks/time_series_network/skip_recurrent.py @@ -6,7 +6,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class SkipRecurrentNet(keras.Layer): """ Implements a Skip recurrent layer as described in [1], allowing a more flexible recurrent backbone diff --git a/bayesflow/networks/time_series_network/time_series_network.py b/bayesflow/networks/time_series_network/time_series_network.py index 354806f6c..7a96a099d 100644 --- a/bayesflow/networks/time_series_network/time_series_network.py +++ b/bayesflow/networks/time_series_network/time_series_network.py @@ -7,7 +7,7 @@ from ..summary_network import SummaryNetwork -@serializable +@serializable("bayesflow.networks") class TimeSeriesNetwork(SummaryNetwork): """ Implements a LSTNet Architecture as described in [1] diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py index 1821c25d2..f416957fb 100644 --- a/bayesflow/networks/transformers/fusion_transformer.py +++ b/bayesflow/networks/transformers/fusion_transformer.py @@ -10,7 +10,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class 
FusionTransformer(SummaryNetwork): """Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers followed by cross-attention between the representation and a learnable template summarized via a recurrent net.""" diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py index ae1242469..03f15a561 100644 --- a/bayesflow/networks/transformers/isab.py +++ b/bayesflow/networks/transformers/isab.py @@ -7,7 +7,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class InducedSetAttentionBlock(keras.Layer): """Implements the ISAB block from [1] which represents learnable self-attention specifically designed to deal with large sets via a learnable set of "inducing points". diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 8f0e3f881..5bd7c9dff 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -8,7 +8,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.networks") class MultiHeadAttentionBlock(keras.Layer): """Implements the MAB block from [1] which represents learnable cross-attention. diff --git a/bayesflow/networks/transformers/pma.py b/bayesflow/networks/transformers/pma.py index 956c85b48..bdcb2f983 100644 --- a/bayesflow/networks/transformers/pma.py +++ b/bayesflow/networks/transformers/pma.py @@ -10,7 +10,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class PoolingByMultiHeadAttention(keras.Layer): """Implements the pooling with multi-head attention (PMA) block from [1] which represents a permutation-invariant encoder for set-based inputs. 
diff --git a/bayesflow/networks/transformers/sab.py b/bayesflow/networks/transformers/sab.py index a447d92a2..276383dfd 100644 --- a/bayesflow/networks/transformers/sab.py +++ b/bayesflow/networks/transformers/sab.py @@ -7,7 +7,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class SetAttentionBlock(MultiHeadAttentionBlock): """Implements the SAB block from [1] which represents learnable self-attention. diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 6c0ab0efc..256d9e54d 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -11,7 +11,7 @@ from .pma import PoolingByMultiHeadAttention -@serializable +@serializable("bayesflow.networks") class SetTransformer(SummaryNetwork): """Implements the set transformer architecture from [1] which ultimately represents a learnable permutation-invariant function. 
Designed to naturally model interactions in diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py index 16feca444..007ae8b74 100644 --- a/bayesflow/networks/transformers/time_series_transformer.py +++ b/bayesflow/networks/transformers/time_series_transformer.py @@ -10,7 +10,7 @@ from .mab import MultiHeadAttentionBlock -@serializable +@serializable("bayesflow.networks") class TimeSeriesTransformer(SummaryNetwork): def __init__( self, diff --git a/bayesflow/scores/mean_score.py b/bayesflow/scores/mean_score.py index 553a7c3af..0c7f200b2 100644 --- a/bayesflow/scores/mean_score.py +++ b/bayesflow/scores/mean_score.py @@ -1,9 +1,8 @@ -from keras.saving import register_keras_serializable as serializable - +from bayesflow.utils.serialization import serializable from .normed_difference_score import NormedDifferenceScore -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class MeanScore(NormedDifferenceScore): r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |^2` diff --git a/bayesflow/scores/median_score.py b/bayesflow/scores/median_score.py index 10c8809c3..385c47436 100644 --- a/bayesflow/scores/median_score.py +++ b/bayesflow/scores/median_score.py @@ -1,9 +1,8 @@ -from keras.saving import register_keras_serializable as serializable - +from bayesflow.utils.serialization import serializable from .normed_difference_score import NormedDifferenceScore -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class MedianScore(NormedDifferenceScore): r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |` diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py index 84cfd4910..7c745919c 100644 --- a/bayesflow/scores/multivariate_normal_score.py +++ b/bayesflow/scores/multivariate_normal_score.py @@ -1,15 +1,15 @@ import math import keras -from keras.saving import 
register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.links import PositiveDefinite +from bayesflow.utils.serialization import serializable from .parametric_distribution_score import ParametricDistributionScore -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class MultivariateNormalScore(ParametricDistributionScore): r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))` diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py index eb2795927..d33bc128f 100644 --- a/bayesflow/scores/normed_difference_score.py +++ b/bayesflow/scores/normed_difference_score.py @@ -1,13 +1,13 @@ import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import weighted_mean +from bayesflow.utils.serialization import serializable from .scoring_rule import ScoringRule -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class NormedDifferenceScore(ScoringRule): r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k` diff --git a/bayesflow/scores/parametric_distribution_score.py b/bayesflow/scores/parametric_distribution_score.py index 3ead3271f..91df32d48 100644 --- a/bayesflow/scores/parametric_distribution_score.py +++ b/bayesflow/scores/parametric_distribution_score.py @@ -1,12 +1,11 @@ -from keras.saving import register_keras_serializable as serializable - from bayesflow.types import Tensor from bayesflow.utils import weighted_mean +from bayesflow.utils.serialization import serializable from .scoring_rule import ScoringRule -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class ParametricDistributionScore(ScoringRule): r""":math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))` diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py index 
b05a35fc5..7ba021340 100644 --- a/bayesflow/scores/quantile_score.py +++ b/bayesflow/scores/quantile_score.py @@ -1,16 +1,16 @@ from typing import Sequence import keras -from keras.saving import register_keras_serializable as serializable from bayesflow.types import Shape, Tensor from bayesflow.utils import logging, weighted_mean +from bayesflow.utils.serialization import serializable from bayesflow.links import OrderedQuantiles from .scoring_rule import ScoringRule -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class QuantileScore(ScoringRule): r""":math:`S(\hat \theta_i, \theta; \tau_i) = (\hat \theta_i - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau_i)` diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py index 0144de458..6dee0afec 100644 --- a/bayesflow/scores/scoring_rule.py +++ b/bayesflow/scores/scoring_rule.py @@ -7,7 +7,7 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable(package="bayesflow.scores") +@serializable("bayesflow.scores") class ScoringRule: """Base class for scoring rules. diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py index bb55aee41..5be0e0e1d 100644 --- a/bayesflow/utils/serialization.py +++ b/bayesflow/utils/serialization.py @@ -96,28 +96,49 @@ def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): @allow_args -def serializable(cls, package: str | None = None, name: str | None = None): - """Register class as Keras serialize. +def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False): + """Register class as Keras serializable. - Wrapper function around `keras.saving.register_keras_serializable` to automatically - set the `package` and `name` arguments. + Wrapper function around `keras.saving.register_keras_serializable` to automatically check consistency + of the supplied `package` argument with the module a class resides in. 
The `package` name should generally + be the module the class resides in, truncated at depth two. Valid examples would be "bayesflow.networks" + or "bayesflow.adapters". The check can be disabled if necessary by setting `disable_module_check` to True. + This should only be done in exceptional cases, and accompanied by a comment why it is necessary for a given + class. Parameters ---------- cls : type The class to register. - package : str, optional + package : str `package` argument forwarded to `keras.saving.register_keras_serializable`. - If None is provided, the package is automatically inferred using the __name__ - attribute of the module the class resides in. + Should generally correspond to the module of the class, truncated at depth two (e.g., "bayesflow.networks"). name : str, optional `name` argument forwarded to `keras.saving.register_keras_serializable`. If None is provided, the classe's __name__ attribute is used. + disable_module_check : bool, optional + Disable check that the provided `package` is consistent with the location of the class within the library. + + Raises + ------ + ValueError + If the supplied `package` does not correspond to the module of the class, truncated at depth two, and + `disable_module_check` is False. No error is thrown when a class is not part of the bayesflow module. """ - if package is None: + if not disable_module_check: frame = sys._getframe(2) g = frame.f_globals - package = g.get("__name__", "bayesflow") + module_name = g.get("__name__", "") + # only apply this check if the class is inside the bayesflow module + is_bayesflow = module_name.split(".")[0] == "bayesflow" + auto_package = ".".join(module_name.split(".")[:2]) + if is_bayesflow and package != auto_package: + raise ValueError( + "'package' should be the first two levels of the module the class resides in (e.g., bayesflow.networks)" + f'. In this case it should be \'package="{auto_package}"\' (was "{package}"). 
If this is not possible' + " (e.g., because a class was moved to a different module, and serializability should be preserved)," + " please set 'disable_module_check=True' and add a comment why it is necessary for this class." + ) if name is None: name = copy(cls.__name__) diff --git a/bayesflow/wrappers/mamba/mamba.py b/bayesflow/wrappers/mamba/mamba.py index d06508790..b328ede98 100644 --- a/bayesflow/wrappers/mamba/mamba.py +++ b/bayesflow/wrappers/mamba/mamba.py @@ -9,7 +9,7 @@ from .mamba_block import MambaBlock -@serializable +@serializable("bayesflow.wrappers") class Mamba(SummaryNetwork): """ Wraps a sequence of Mamba modules using the simple Mamba module from: diff --git a/bayesflow/wrappers/mamba/mamba_block.py b/bayesflow/wrappers/mamba/mamba_block.py index b8ba36d2e..bd15ecc29 100644 --- a/bayesflow/wrappers/mamba/mamba_block.py +++ b/bayesflow/wrappers/mamba/mamba_block.py @@ -6,7 +6,7 @@ from bayesflow.utils.serialization import serializable -@serializable +@serializable("bayesflow.wrappers") class MambaBlock(keras.Layer): """ Wraps the original Mamba module from, with added functionality for bidirectional processing: diff --git a/docsrc/source/development/serialization.md b/docsrc/source/development/serialization.md index 15c812686..ec8988454 100644 --- a/docsrc/source/development/serialization.md +++ b/docsrc/source/development/serialization.md @@ -17,10 +17,19 @@ As we want/need to pass them in some places, we have to resort to some custom be BayesFlows serialization utilities can be found in the {py:mod}`~bayesflow.utils.serialization` module. We mainly provide three convenience functions: -- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to provide automatic `package` and `name` arguments. 
+- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to ensure consistent naming of the `package` argument within the library. - The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes. - Its counterpart {py:func}`~bayesflow.utils.serialization.deserialize`, adds support to deserialize classes. ## Usage To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` method. Please refer to existing classes in the library for usage examples. + +### The `serializable` Decorator + +To keep serialization as easy to follow as possible, and to provide stability even when classes are moved around, we provide the `package` argument explicitly for each class. +The naming should respect the following naming scheme: Take the module the class resides in (for example, `bayesflow.adapters.transforms.standardize`), and truncate the path to depth two (`bayesflow.adapters`). +In cases where this convention cannot be followed, set `disable_module_check` to `True`, and describe why a different name was necessary. +Changing `package` breaks backwards-compatibility for serialization, so it should be avoided whenever possible. +If you move a class to a different module (without changing the class itself), keep the `package` and set `disable_module_check` to `True`. +This may later be adapted in a release that breaks backward compatibility anyway. 
diff --git a/tests/test_utils/test_serialize_deserialize.py b/tests/test_utils/test_serialize_deserialize.py index 6c3bc3983..a9888ecc5 100644 --- a/tests/test_utils/test_serialize_deserialize.py +++ b/tests/test_utils/test_serialize_deserialize.py @@ -3,7 +3,7 @@ from bayesflow.utils.serialization import deserialize, serializable, serialize -@serializable +@serializable("test", disable_module_check=True) class Foo: @classmethod def from_config(cls, config, custom_objects=None): @@ -13,7 +13,7 @@ def get_config(self): return {} -@serializable +@serializable("test", disable_module_check=True) class Bar: @classmethod def from_config(cls, config, custom_objects=None): diff --git a/tests/test_workflows/conftest.py b/tests/test_workflows/conftest.py index c98e543e9..e66b6efe4 100644 --- a/tests/test_workflows/conftest.py +++ b/tests/test_workflows/conftest.py @@ -40,7 +40,7 @@ def summary_network(request): elif request.param == "custom": from bayesflow.networks import SummaryNetwork - @serializable + @serializable("test", disable_module_check=True) class Custom(SummaryNetwork): def __init__(self, **kwargs): super().__init__(**kwargs) From 335e54674af7b3f561ea6f980ea05dd6de1676ec Mon Sep 17 00:00:00 2001 From: LarsKue Date: Mon, 5 May 2025 10:43:14 -0400 Subject: [PATCH 45/46] improve adding __version__ attribute --- bayesflow/__init__.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py index 5a28ffe2e..008afc89b 100644 --- a/bayesflow/__init__.py +++ b/bayesflow/__init__.py @@ -50,17 +50,11 @@ def setup(): "in contexts where you need gradients (e.g. custom training loops)." 
) + # dynamically add __version__ attribute + from importlib.metadata import version -# dynamically add version dunder variable -try: - from importlib.metadata import version, PackageNotFoundError + globals()["__version__"] = version("bayesflow") - __version__ = version(__name__) -except PackageNotFoundError: - __version__ = "2.0.0" -finally: - del version - del PackageNotFoundError # call and clean up namespace setup() From 52bdb581300aaf5053baca3dccd567e415be29a1 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Mon, 5 May 2025 10:43:22 -0400 Subject: [PATCH 46/46] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a42b1df97..07a5e924b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "bayesflow" -version = "2.0.2" +version = "2.0.3" authors = [{ name = "The BayesFlow Team" }] classifiers = [ "Development Status :: 5 - Production/Stable",