Skip to content

Commit 60050ec

Browse files
committed
Merge branch 'dev' into point-estimation
2 parents aa5dd93 + 77dee84 commit 60050ec

File tree

157 files changed

+2609
-742
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

157 files changed

+2609
-742
lines changed

.github/workflows/docs.yaml renamed to .github/workflows/multiversion-docs.yaml

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11

22
# From https://github.com/eeholmes/readthedoc-test/blob/main/.github/workflows/docs_pages.yml
3-
name: docs
3+
name: multiversion-docs
44

5-
# execute this workflow automatically when we push to master
65
on:
7-
push:
8-
branches:
9-
- master
6+
workflow_dispatch:
7+
# execute this workflow automatically when we push to master or dev
8+
# push:
9+
# branches:
10+
# - master
11+
# - dev
1012

1113
jobs:
1214

@@ -17,13 +19,15 @@ jobs:
1719
- name: Checkout main
1820
uses: actions/checkout@v3
1921
with:
20-
path: master
22+
path: dev
23+
fetch-depth: 0
24+
fetch-tags: true
2125

22-
- name: Checkout gh-pages
26+
- name: Checkout gh-pages-dev
2327
uses: actions/checkout@v3
2428
with:
25-
path: gh-pages
26-
ref: gh-pages
29+
path: gh-pages-dev
30+
ref: gh-pages-dev
2731

2832
- name: Set up Python
2933
uses: actions/setup-python@v4
@@ -33,17 +37,21 @@ jobs:
3337

3438
- name: Install dependencies
3539
run: |
36-
cd ./master
40+
cd ./dev
3741
python -m pip install .[docs]
42+
- name: Create local branches
43+
run: |
44+
cd ./dev
45+
git branch master remotes/origin/master
3846
- name: Make the Sphinx docs
3947
run: |
40-
cd ./master/docsrc
48+
cd ./dev/docsrc
4149
make clean
4250
make github
4351
- name: Commit changes to docs
4452
run: |
45-
cd ./gh-pages
46-
cp -R ../master/docs/* ./
53+
cd ./gh-pages-dev
54+
cp -R ../dev/docs/* ./
4755
git config --local user.email ""
4856
git config --local user.name "github-actions"
4957
git add -A

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ __pycache__/
55
projects/
66
*/bayesflow.egg-info
77
docsrc/_build/
8+
docsrc/_build_polyversion
9+
docsrc/.bf_doc_gen_venv
10+
docsrc/source/api
11+
docsrc/source/_examples
12+
docsrc/source/contributing.md
813
build
914
docs/
1015

CONTRIBUTING.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Make sure to occasionally also run multi-backend tests for your OS using [tox](h
100100
tox --parallel auto
101101
```
102102

103-
See [tox.ini](tox.ini) for details on the environment configurations.
103+
See `tox.ini` for details on the environment configurations.
104104
Multi-OS tests will automatically be run once you create a pull request.
105105

106106
Note that to be backend-agnostic, your code must not:
@@ -137,12 +137,24 @@ z = keras.ops.convert_to_numpy(x)
137137
### 4. Document your changes
138138

139139
The documentation uses [sphinx](https://www.sphinx-doc.org/) and relies on [numpy style docstrings](https://numpydoc.readthedocs.io/en/latest/format.html) in classes and functions.
140-
The overall *structure* of the documentation is manually designed. This also applies to the API documentation. This has two implications for you:
141140

142-
1. If you add to existing submodules, the documentation will update automatically (given that you use proper numpy docstrings).
143-
2. If you add a new submodule or subpackage, you need to add a file to `docsrc/source/api` and a reference to the new module to the appropriate section of `docsrc/source/api/bayesflow.rst`.
141+
Run the following command to install all necessary packages for setting up documentation generation:
144142

145-
You can re-build the documentation with
143+
```
144+
pip install .[docs]
145+
```
146+
147+
The overall *structure* of the documentation is manually designed, but the API documentation is auto-generated.
148+
149+
You can re-build the current documentation with
150+
151+
```bash
152+
cd docsrc
153+
make clean && make dev
154+
# in case of issues, try `make clean-all`
155+
```
156+
157+
We also provide a multi-version documentation. To generate it, run
146158

147159
```bash
148160
cd docsrc

bayesflow/adapters/adapter.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,18 @@ def as_time_series(self, keys: str | Sequence[str]):
122122
return self
123123

124124
def broadcast(
125-
self, keys: str | Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1
125+
self,
126+
keys: str | Sequence[str],
127+
*,
128+
to: str,
129+
expand: str | int | tuple = "left",
130+
exclude: int | tuple = -1,
131+
squeeze: int | tuple = None,
126132
):
127133
if isinstance(keys, str):
128134
keys = [keys]
129135

130-
transform = Broadcast(keys, to=to, expand=expand, exclude=exclude)
136+
transform = Broadcast(keys, to=to, expand=expand, exclude=exclude, squeeze=squeeze)
131137
self.transforms.append(transform)
132138
return self
133139

bayesflow/adapters/transforms/broadcast.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,15 @@ class Broadcast(Transform):
5353
It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
5454
"""
5555

56-
def __init__(self, keys: Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1):
56+
def __init__(
57+
self,
58+
keys: Sequence[str],
59+
*,
60+
to: str,
61+
expand: str | int | tuple = "left",
62+
exclude: int | tuple = -1,
63+
squeeze: int | tuple = None,
64+
):
5765
super().__init__()
5866
self.keys = keys
5967
self.to = to
@@ -67,6 +75,7 @@ def __init__(self, keys: Sequence[str], *, to: str, expand: str | int | tuple =
6775
exclude = (exclude,)
6876

6977
self.exclude = exclude
78+
self.squeeze = squeeze
7079

7180
@classmethod
7281
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
@@ -75,6 +84,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
7584
to=deserialize(config["to"], custom_objects),
7685
expand=deserialize(config["expand"], custom_objects),
7786
exclude=deserialize(config["exclude"], custom_objects),
87+
squeeze=deserialize(config["squeeze"], custom_objects),
7888
)
7989

8090
def get_config(self) -> dict:
@@ -83,6 +93,7 @@ def get_config(self) -> dict:
8393
"to": serialize(self.to),
8494
"expand": serialize(self.expand),
8595
"exclude": serialize(self.exclude),
96+
"squeeze": serialize(self.squeeze),
8697
}
8798

8899
# noinspection PyMethodOverriding
@@ -115,6 +126,9 @@ def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray
115126

116127
data[k] = np.broadcast_to(data[k], new_shape)
117128

129+
if self.squeeze is not None:
130+
data[k] = np.squeeze(data[k], axis=self.squeeze)
131+
118132
return data
119133

120134
# noinspection PyMethodOverriding

bayesflow/approximators/continuous_approximator.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bayesflow.adapters import Adapter
1212
from bayesflow.networks import InferenceNetwork, SummaryNetwork
1313
from bayesflow.types import Tensor
14-
from bayesflow.utils import logging, split_arrays
14+
from bayesflow.utils import filter_kwargs, logging, split_arrays
1515
from .approximator import Approximator
1616

1717

@@ -141,7 +141,7 @@ def sample(
141141
) -> dict[str, np.ndarray]:
142142
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
143143
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
144-
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions)}
144+
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
145145
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
146146
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
147147

@@ -154,6 +154,7 @@ def _sample(
154154
num_samples: int,
155155
inference_conditions: Tensor = None,
156156
summary_variables: Tensor = None,
157+
**kwargs,
157158
) -> Tensor:
158159
if self.summary_network is None:
159160
if summary_variables is not None:
@@ -162,7 +163,9 @@ def _sample(
162163
if summary_variables is None:
163164
raise ValueError("Summary variables are required when a summary network is present.")
164165

165-
summary_outputs = self.summary_network(summary_variables)
166+
summary_outputs = self.summary_network(
167+
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
168+
)
166169

167170
if inference_conditions is None:
168171
inference_conditions = summary_outputs
@@ -180,18 +183,26 @@ def _sample(
180183
else:
181184
batch_shape = (num_samples,)
182185

183-
return self.inference_network.sample(batch_shape, conditions=inference_conditions)
186+
return self.inference_network.sample(
187+
batch_shape,
188+
conditions=inference_conditions,
189+
**filter_kwargs(kwargs, self.inference_network.sample),
190+
)
184191

185-
def log_prob(self, data: dict[str, np.ndarray]) -> np.ndarray:
186-
data = self.adapter(data, strict=False, stage="inference")
192+
def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
193+
data = self.adapter(data, strict=False, stage="inference", **kwargs)
187194
data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
188-
log_prob = self._log_prob(**data)
195+
log_prob = self._log_prob(**data, **kwargs)
189196
log_prob = keras.ops.convert_to_numpy(log_prob)
190197

191198
return log_prob
192199

193200
def _log_prob(
194-
self, inference_variables: Tensor, inference_conditions: Tensor = None, summary_variables: Tensor = None
201+
self,
202+
inference_variables: Tensor,
203+
inference_conditions: Tensor = None,
204+
summary_variables: Tensor = None,
205+
**kwargs,
195206
) -> Tensor:
196207
if self.summary_network is None:
197208
if summary_variables is not None:
@@ -200,11 +211,17 @@ def _log_prob(
200211
if summary_variables is None:
201212
raise ValueError("Summary variables are required when a summary network is present.")
202213

203-
summary_outputs = self.summary_network(summary_variables)
214+
summary_outputs = self.summary_network(
215+
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
216+
)
204217

205218
if inference_conditions is None:
206219
inference_conditions = summary_outputs
207220
else:
208221
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
209222

210-
return self.inference_network.log_prob(inference_variables, conditions=inference_conditions)
223+
return self.inference_network.log_prob(
224+
inference_variables,
225+
conditions=inference_conditions,
226+
**filter_kwargs(kwargs, self.inference_network.log_prob),
227+
)

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
from bayesflow.types import Tensor
10-
from bayesflow.utils import find_network, keras_kwargs
10+
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
1111

1212

1313
from ..inference_network import InferenceNetwork
@@ -88,6 +88,27 @@ def __init__(
8888

8989
self.seed_generator = keras.random.SeedGenerator()
9090

91+
# serialization: store all parameters necessary to call __init__
92+
self.config = {
93+
"total_steps": total_steps,
94+
"max_time": max_time,
95+
"sigma2": sigma2,
96+
"eps": eps,
97+
"s0": s0,
98+
"s1": s1,
99+
**kwargs,
100+
}
101+
self.config = serialize_value_or_type(self.config, "subnet", subnet)
102+
103+
def get_config(self):
104+
base_config = super().get_config()
105+
return base_config | self.config
106+
107+
@classmethod
108+
def from_config(cls, config):
109+
config = deserialize_value_or_type(config, "subnet")
110+
return cls(**config)
111+
91112
def _schedule_discretization(self, step) -> float:
92113
"""Schedule function for adjusting the discretization level `N` during
93114
the course of training.

bayesflow/networks/consistency_models/continuous_consistency_model.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,16 @@
77
import numpy as np
88

99
from bayesflow.types import Tensor
10-
from bayesflow.utils import jvp, concatenate, find_network, keras_kwargs, expand_right_as, expand_right_to
10+
from bayesflow.utils import (
11+
jvp,
12+
concatenate,
13+
find_network,
14+
keras_kwargs,
15+
expand_right_as,
16+
expand_right_to,
17+
serialize_value_or_type,
18+
deserialize_value_or_type,
19+
)
1120

1221

1322
from ..inference_network import InferenceNetwork
@@ -62,6 +71,22 @@ def __init__(
6271

6372
self.seed_generator = keras.random.SeedGenerator()
6473

74+
# serialization: store all parameters necessary to call __init__
75+
self.config = {
76+
"sigma_data": sigma_data,
77+
**kwargs,
78+
}
79+
self.config = serialize_value_or_type(self.config, "subnet", subnet)
80+
81+
def get_config(self):
82+
base_config = super().get_config()
83+
return base_config | self.config
84+
85+
@classmethod
86+
def from_config(cls, config):
87+
config = deserialize_value_or_type(config, "subnet")
88+
return cls(**config)
89+
6590
def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
6691
t = np.linspace(0.0, np.pi / 2, num_steps)
6792
times = np.exp((t - np.pi / 2) * rho) * np.pi / 2

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.types import Tensor
5-
from bayesflow.utils import find_permutation, keras_kwargs
5+
from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
66

77
from .actnorm import ActNorm
88
from .couplings import DualCoupling
@@ -58,13 +58,33 @@ def __init__(
5858

5959
self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {})))
6060

61+
# serialization: store all parameters necessary to call __init__
62+
self.config = {
63+
"depth": depth,
64+
"transform": transform,
65+
"permutation": permutation,
66+
"use_actnorm": use_actnorm,
67+
"base_distribution": base_distribution,
68+
**kwargs,
69+
}
70+
self.config = serialize_value_or_type(self.config, "subnet", subnet)
71+
6172
# noinspection PyMethodOverriding
6273
def build(self, xz_shape, conditions_shape=None):
6374
super().build(xz_shape)
6475

6576
for layer in self.invertible_layers:
6677
layer.build(xz_shape=xz_shape, conditions_shape=conditions_shape)
6778

79+
def get_config(self):
80+
base_config = super().get_config()
81+
return base_config | self.config
82+
83+
@classmethod
84+
def from_config(cls, config):
85+
config = deserialize_value_or_type(config, "subnet")
86+
return cls(**config)
87+
6888
def _forward(
6989
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
7090
) -> Tensor | tuple[Tensor, Tensor]:

0 commit comments

Comments
 (0)