Skip to content

Commit 5f11724

Browse files
authored
Merge branch 'main' into feat-diffusion-model
2 parents 49c0cb7 + d31a761 commit 5f11724

File tree

149 files changed

+3765
-3079
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+3765
-3079
lines changed

.github/workflows/publish.yaml

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,53 @@
1+
name: Publish Python 🐍 distribution 📦 to PyPI
12

2-
name: Publish to PyPI.org
33
on:
44
release:
55
types: [published]
6+
67
jobs:
7-
pypi:
8+
build:
9+
name: Build distribution 📦
10+
runs-on: ubuntu-latest
11+
12+
steps:
13+
- uses: actions/checkout@v4
14+
with:
15+
persist-credentials: false
16+
- name: Set up Python
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: "3.x"
20+
- name: Install pypa/build
21+
run: >-
22+
python3 -m
23+
pip install
24+
build
25+
--user
26+
- name: Build a binary wheel and a source tarball
27+
run: python3 -m build
28+
- name: Store the distribution packages
29+
uses: actions/upload-artifact@v4
30+
with:
31+
name: python-package-distributions
32+
path: dist/
33+
34+
publish-to-pypi:
35+
name: >-
36+
Publish Python 🐍 distribution 📦 to PyPI
37+
needs:
38+
- build
839
runs-on: ubuntu-latest
40+
environment:
41+
name: pypi
42+
url: https://pypi.org/p/bayesflow # Replace <package-name> with your PyPI project name
43+
permissions:
44+
id-token: write # IMPORTANT: mandatory for trusted publishing
45+
946
steps:
10-
- name: Checkout
11-
uses: actions/checkout@v4
12-
with:
13-
fetch-depth: 0
14-
- run: python3 -m pip install -U build && python3 -m build
15-
- name: Publish package
16-
uses: pypa/gh-action-pypi-publish@release/v1
17-
with:
18-
password: ${{ secrets.PYPI_API_TOKEN }}
47+
- name: Download all the dists
48+
uses: actions/download-artifact@v4
49+
with:
50+
name: python-package-distributions
51+
path: dist/
52+
- name: Publish distribution 📦 to PyPI
53+
uses: pypa/gh-action-pypi-publish@release/v1

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,19 @@ More tutorials are always welcome! Please consider making a pull request if you
6464

6565
## Install
6666

67-
BayesFlow v2 is not yet installable via PyPI, but you can use the following command to install the latest version of the `main` branch:
67+
You can install the latest stable version from PyPI using:
6868

6969
```bash
70-
pip install git+https://github.com/bayesflow-org/bayesflow.git
70+
pip install bayesflow
7171
```
7272

73-
If you encounter problems with this or require more control, please refer to the instructions to install from source below.
73+
If you want the latest features, you can install from source:
7474

75-
Note: `pip install bayesflow` will install the v1 version of BayesFlow.
75+
```bash
76+
pip install git+https://github.com/bayesflow-org/bayesflow.git@dev
77+
```
78+
79+
If you encounter problems with this or require more control, please refer to the instructions to install from source below.
7680

7781
### Backend
7882

bayesflow/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ def setup():
5151
)
5252

5353

54+
# dynamically add version dunder variable
55+
try:
56+
from importlib.metadata import version, PackageNotFoundError
57+
58+
__version__ = version(__name__)
59+
except PackageNotFoundError:
60+
__version__ = "2.0.0"
61+
finally:
62+
del version
63+
del PackageNotFoundError
64+
5465
# call and clean up namespace
5566
setup()
5667
del setup

bayesflow/adapters/adapter.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22

33
import numpy as np
44

5-
from keras.saving import (
6-
deserialize_keras_object as deserialize,
7-
register_keras_serializable as serializable,
8-
serialize_keras_object as serialize,
9-
)
5+
from bayesflow.utils.serialization import deserialize, serialize, serializable
106

117
from .transforms import (
128
AsSet,
@@ -33,7 +29,7 @@
3329
from .transforms.filter_transform import Predicate
3430

3531

36-
@serializable(package="bayesflow.adapters")
32+
@serializable
3733
class Adapter(MutableSequence[Transform]):
3834
"""
3935
Defines an adapter to apply various transforms to data.
@@ -74,18 +70,24 @@ def create_default(inference_variables: Sequence[str]) -> "Adapter":
7470

7571
@classmethod
7672
def from_config(cls, config: dict, custom_objects=None) -> "Adapter":
77-
return cls(transforms=deserialize(config["transforms"], custom_objects))
73+
return cls(**deserialize(config, custom_objects=custom_objects))
7874

7975
def get_config(self) -> dict:
80-
return {"transforms": serialize(self.transforms)}
76+
config = {
77+
"transforms": self.transforms,
78+
}
79+
80+
return serialize(config)
8181

82-
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
82+
def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
8383
"""Apply the transforms in the forward direction.
8484
8585
Parameters
8686
----------
8787
data : dict
8888
The data to be transformed.
89+
stage : str, one of ["training", "validation", "inference"]
90+
The stage the function is called in.
8991
**kwargs : dict
9092
Additional keyword arguments passed to each transform.
9193
@@ -97,17 +99,19 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
9799
data = data.copy()
98100

99101
for transform in self.transforms:
100-
data = transform(data, **kwargs)
102+
data = transform(data, stage=stage, **kwargs)
101103

102104
return data
103105

104-
def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
106+
def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
105107
"""Apply the transforms in the inverse direction.
106108
107109
Parameters
108110
----------
109111
data : dict
110112
The data to be transformed.
113+
stage : str, one of ["training", "validation", "inference"]
114+
The stage the function is called in.
111115
**kwargs : dict
112116
Additional keyword arguments passed to each transform.
113117
@@ -119,11 +123,13 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
119123
data = data.copy()
120124

121125
for transform in reversed(self.transforms):
122-
data = transform(data, inverse=True, **kwargs)
126+
data = transform(data, stage=stage, inverse=True, **kwargs)
123127

124128
return data
125129

126-
def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
130+
def __call__(
131+
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
132+
) -> dict[str, np.ndarray]:
127133
"""Apply the transforms in the given direction.
128134
129135
Parameters
@@ -132,6 +138,8 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
132138
The data to be transformed.
133139
inverse : bool, optional
134140
If False, apply the forward transform, else apply the inverse transform (default False).
141+
stage : str, one of ["training", "validation", "inference"]
142+
The stage the function is called in.
135143
**kwargs
136144
Additional keyword arguments passed to each transform.
137145
@@ -141,9 +149,9 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
141149
The transformed data.
142150
"""
143151
if inverse:
144-
return self.inverse(data, **kwargs)
152+
return self.inverse(data, stage=stage, **kwargs)
145153

146-
return self.forward(data, **kwargs)
154+
return self.forward(data, stage=stage, **kwargs)
147155

148156
def __repr__(self):
149157
result = ""
@@ -667,6 +675,18 @@ def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
667675
self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
668676
return self
669677

678+
def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Sequence[int] = None, axis: int = -1):
679+
from .transforms import Split
680+
681+
if isinstance(into, str):
682+
transform = Rename(key, into)
683+
else:
684+
transform = Split(key, into, indices_or_sections, axis)
685+
686+
self.transforms.append(transform)
687+
688+
return self
689+
670690
def sqrt(self, keys: str | Sequence[str]):
671691
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
672692
@@ -743,3 +763,10 @@ def to_array(
743763
)
744764
self.transforms.append(transform)
745765
return self
766+
767+
def to_dict(self):
768+
from .transforms import ToDict
769+
770+
transform = ToDict()
771+
self.transforms.append(transform)
772+
return self

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from .scale import Scale
1818
from .serializable_custom_transform import SerializableCustomTransform
1919
from .shift import Shift
20+
from .split import Split
2021
from .sqrt import Sqrt
2122
from .standardize import Standardize
2223
from .to_array import ToArray
24+
from .to_dict import ToDict
2325
from .transform import Transform
2426

2527
from ...utils._docs import _add_imports_to_all

bayesflow/adapters/transforms/as_set.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from keras.saving import register_keras_serializable as serializable
21
import numpy as np
32

3+
from bayesflow.utils.serialization import serializable
4+
45
from .elementwise_transform import ElementwiseTransform
56

67

7-
@serializable(package="bayesflow.adapters")
8+
@serializable
89
class AsSet(ElementwiseTransform):
910
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
1011
@@ -33,9 +34,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
3334

3435
return data
3536

36-
@classmethod
37-
def from_config(cls, config: dict, custom_objects=None) -> "AsSet":
38-
return cls()
39-
4037
def get_config(self) -> dict:
4138
return {}

bayesflow/adapters/transforms/as_time_series.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
2-
from keras.saving import register_keras_serializable as serializable
2+
3+
from bayesflow.utils.serialization import serializable
34

45
from .elementwise_transform import ElementwiseTransform
56

67

7-
@serializable(package="bayesflow.adapters")
8+
@serializable
89
class AsTimeSeries(ElementwiseTransform):
910
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
1011
@@ -29,9 +30,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
2930

3031
return data
3132

32-
@classmethod
33-
def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries":
34-
return cls()
35-
3633
def get_config(self) -> dict:
3734
return {}

bayesflow/adapters/transforms/broadcast.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from collections.abc import Sequence
22
import numpy as np
33

4-
from keras.saving import (
5-
deserialize_keras_object as deserialize,
6-
register_keras_serializable as serializable,
7-
serialize_keras_object as serialize,
8-
)
4+
from bayesflow.utils.serialization import serialize, serializable
95

106
from .transform import Transform
117

128

13-
@serializable(package="bayesflow.adapters")
9+
@serializable
1410
class Broadcast(Transform):
1511
"""
1612
Broadcasts arrays or scalars to the shape of a given other array.
@@ -96,31 +92,15 @@ def __init__(
9692
self.exclude = exclude
9793
self.squeeze = squeeze
9894

99-
@classmethod
100-
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
101-
# Deserialize turns tuples to lists, undo it if necessary
102-
exclude = deserialize(config["exclude"], custom_objects)
103-
exclude = tuple(exclude) if isinstance(exclude, list) else exclude
104-
expand = deserialize(config["expand"], custom_objects)
105-
expand = tuple(expand) if isinstance(expand, list) else expand
106-
squeeze = deserialize(config["squeeze"], custom_objects)
107-
squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze
108-
return cls(
109-
keys=deserialize(config["keys"], custom_objects),
110-
to=deserialize(config["to"], custom_objects),
111-
expand=expand,
112-
exclude=exclude,
113-
squeeze=squeeze,
114-
)
115-
11695
def get_config(self) -> dict:
117-
return {
118-
"keys": serialize(self.keys),
119-
"to": serialize(self.to),
120-
"expand": serialize(self.expand),
121-
"exclude": serialize(self.exclude),
122-
"squeeze": serialize(self.squeeze),
96+
config = {
97+
"keys": self.keys,
98+
"to": self.to,
99+
"expand": self.expand,
100+
"exclude": self.exclude,
101+
"squeeze": self.squeeze,
123102
}
103+
return serialize(config)
124104

125105
# noinspection PyMethodOverriding
126106
def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:

0 commit comments

Comments
 (0)