Skip to content

Commit c963b5b

Browse files
committed
Merge branch 'dev' into adapater_nan
# Conflicts: # bayesflow/adapters/adapter.py # bayesflow/adapters/transforms/__init__.py
2 parents 3a8e313 + 4781e2e commit c963b5b

File tree

230 files changed

+4452
-573
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

230 files changed

+4452
-573
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,6 @@ docs/
3939

4040
# MacOS
4141
.DS_Store
42+
43+
# Rproj
44+
.Rproj.user

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ It provides users and researchers with:
1515
BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference
1616
fueled by continuous progress in generative AI and Bayesian inference.
1717

18+
> [!IMPORTANT]
19+
> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time.
20+
> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards.
21+
> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update
22+
> for less costly models.
23+
1824
## Important Note for Existing Users
1925

2026
You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library.

bayesflow/__init__.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,11 @@ def setup():
5050
"in contexts where you need gradients (e.g. custom training loops)."
5151
)
5252

53+
# dynamically add __version__ attribute
54+
from importlib.metadata import version
5355

54-
# dynamically add version dunder variable
55-
try:
56-
from importlib.metadata import version, PackageNotFoundError
56+
globals()["__version__"] = version("bayesflow")
5757

58-
__version__ = version(__name__)
59-
except PackageNotFoundError:
60-
__version__ = "2.0.0"
61-
finally:
62-
del version
63-
del PackageNotFoundError
6458

6559
# call and clean up namespace
6660
setup()

bayesflow/adapters/adapter.py

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,28 @@
1414
Drop,
1515
ExpandDims,
1616
FilterTransform,
17+
Group,
1718
Keep,
1819
Log,
1920
MapTransform,
2021
NumpyTransform,
2122
OneHot,
2223
Rename,
2324
SerializableCustomTransform,
25+
Squeeze,
2426
Sqrt,
2527
Standardize,
2628
ToArray,
2729
Transform,
30+
Ungroup,
31+
RandomSubsample,
32+
Take,
2833
NanToNum,
2934
)
3035
from .transforms.filter_transform import Predicate
3136

3237

33-
@serializable
38+
@serializable("bayesflow.adapters")
3439
class Adapter(MutableSequence[Transform]):
3540
"""
3641
Defines an adapter to apply various transforms to data.
@@ -599,6 +604,52 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
599604
self.transforms.append(transform)
600605
return self
601606

607+
def group(self, keys: Sequence[str], into: str, *, prefix: str = ""):
608+
"""Append a :py:class:`~transforms.Group` transform to the adapter.
609+
610+
Groups the given variables as a dictionary in the key `into`. As most transforms do
611+
not support nested structures, this should usually be the last transform in the adapter.
612+
613+
Parameters
614+
----------
615+
keys : Sequence of str
616+
The names of the variables to group together.
617+
into : str
618+
The name of the variable to store the grouped variables in.
619+
prefix : str, optional
620+
An optional common prefix of the variable names before grouping, which will be removed after grouping.
621+
622+
Raises
623+
------
624+
ValueError
625+
If a prefix is specified, but a provided key does not start with the prefix.
626+
"""
627+
if isinstance(keys, str):
628+
keys = [keys]
629+
630+
transform = Group(keys=keys, into=into, prefix=prefix)
631+
self.transforms.append(transform)
632+
return self
633+
634+
def ungroup(self, key: str, *, prefix: str = ""):
635+
"""Append an :py:class:`~transforms.Ungroup` transform to the adapter.
636+
637+
Ungroups the the variables in `key` from a dictionary into individual entries. Most transforms do
638+
not support nested structures, so this can be used to flatten a nested structure.
639+
The nesting can be re-established after the transforms using the :py:meth:`group` method.
640+
641+
Parameters
642+
----------
643+
key : str
644+
The name of the variable to ungroup. The corresponding variable has to be a dictionary.
645+
prefix : str, optional
646+
An optional common prefix that will be added to the ungrouped variable names. This can be necessary
647+
to avoid duplicate names.
648+
"""
649+
transform = Ungroup(key=key, prefix=prefix)
650+
self.transforms.append(transform)
651+
return self
652+
602653
def keep(self, keys: str | Sequence[str]):
603654
"""Append a :py:class:`~transforms.Keep` transform to the adapter.
604655
@@ -666,6 +717,28 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
666717
self.transforms.append(transform)
667718
return self
668719

720+
def random_subsample(self, key: str, *, sample_size: int | float, axis: int = -1):
721+
"""
722+
Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
723+
724+
Parameters
725+
----------
726+
key : str or Sequence of str
727+
The name of the variable to subsample.
728+
sample_size : int or float
729+
The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
730+
axis: int, optional
731+
Which axis to draw samples over. The last axis is used by default.
732+
"""
733+
734+
if not isinstance(key, str):
735+
raise TypeError("Can only subsample one batch entry at a time.")
736+
737+
transform = MapTransform({key: RandomSubsample(sample_size=sample_size, axis=axis)})
738+
739+
self.transforms.append(transform)
740+
return self
741+
669742
def rename(self, from_key: str, to_key: str):
670743
"""Append a :py:class:`~transforms.Rename` transform to the adapter.
671744
@@ -709,6 +782,24 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
709782

710783
return self
711784

785+
def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
786+
"""Append a :py:class:`~transforms.Squeeze` transform to the adapter.
787+
788+
Parameters
789+
----------
790+
keys : str or Sequence of str
791+
The names of the variables to squeeze.
792+
axis : int or tuple
793+
The axis to squeeze. As the number of batch dimensions might change, we advise using negative
794+
numbers (i.e., indexing from the end instead of the start).
795+
"""
796+
if isinstance(keys, str):
797+
keys = [keys]
798+
799+
transform = MapTransform({key: Squeeze(axis=axis) for key in keys})
800+
self.transforms.append(transform)
801+
return self
802+
712803
def sqrt(self, keys: str | Sequence[str]):
713804
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
714805
@@ -742,7 +833,7 @@ def standardize(
742833
Names of variables to include in the transform.
743834
exclude : str or Sequence of str, optional
744835
Names of variables to exclude from the transform.
745-
**kwargs : dict
836+
**kwargs :
746837
Additional keyword arguments passed to the transform.
747838
"""
748839
transform = FilterTransform(
@@ -755,6 +846,42 @@ def standardize(
755846
self.transforms.append(transform)
756847
return self
757848

849+
def take(
850+
self,
851+
include: str | Sequence[str] = None,
852+
*,
853+
indices: Sequence[int],
854+
axis: int = -1,
855+
predicate: Predicate = None,
856+
exclude: str | Sequence[str] = None,
857+
):
858+
"""
859+
Append a :py:class:`~transforms.Take` transform to the adapter.
860+
861+
Parameters
862+
----------
863+
include : str or Sequence of str, optional
864+
Names of variables to include in the transform.
865+
indices : Sequence of int
866+
Which indices to take from the data.
867+
axis : int, optional
868+
Which axis to take from. The last axis is used by default.
869+
predicate : Predicate, optional
870+
Function that indicates which variables should be transformed.
871+
exclude : str or Sequence of str, optional
872+
Names of variables to exclude from the transform.
873+
"""
874+
transform = FilterTransform(
875+
transform_constructor=Take,
876+
predicate=predicate,
877+
include=include,
878+
exclude=exclude,
879+
indices=indices,
880+
axis=axis,
881+
)
882+
self.transforms.append(transform)
883+
return self
884+
758885
def to_array(
759886
self,
760887
include: str | Sequence[str] = None,

bayesflow/adapters/transforms/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .elementwise_transform import ElementwiseTransform
99
from .expand_dims import ExpandDims
1010
from .filter_transform import FilterTransform
11+
from .group import Group
1112
from .keep import Keep
1213
from .log import Log
1314
from .map_transform import MapTransform
@@ -18,11 +19,15 @@
1819
from .serializable_custom_transform import SerializableCustomTransform
1920
from .shift import Shift
2021
from .split import Split
22+
from .squeeze import Squeeze
2123
from .sqrt import Sqrt
2224
from .standardize import Standardize
2325
from .to_array import ToArray
2426
from .to_dict import ToDict
2527
from .transform import Transform
28+
from .random_subsample import RandomSubsample
29+
from .take import Take
30+
from .ungroup import Ungroup
2631
from .nan_to_num import NanToNum
2732

2833
from ...utils._docs import _add_imports_to_all

bayesflow/adapters/transforms/as_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class AsSet(ElementwiseTransform):
1010
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
1111

bayesflow/adapters/transforms/as_time_series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class AsTimeSeries(ElementwiseTransform):
1010
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
1111

bayesflow/adapters/transforms/broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .transform import Transform
77

88

9-
@serializable
9+
@serializable("bayesflow.adapters")
1010
class Broadcast(Transform):
1111
"""
1212
Broadcasts arrays or scalars to the shape of a given other array.

bayesflow/adapters/transforms/concatenate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .transform import Transform
88

99

10-
@serializable
10+
@serializable("bayesflow.adapters")
1111
class Concatenate(Transform):
1212
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.
1313

bayesflow/adapters/transforms/constrain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .elementwise_transform import ElementwiseTransform
1212

1313

14-
@serializable
14+
@serializable("bayesflow.adapters")
1515
class Constrain(ElementwiseTransform):
1616
"""
1717
Constrains neural network predictions of a data variable to specified bounds.

0 commit comments

Comments
 (0)