diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index ebef412bf..9edbded92 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -233,11 +233,11 @@ def __len__(self): def apply( self, + include: str | Sequence[str] = None, *, forward: np.ufunc | str, inverse: np.ufunc | str = None, predicate: Predicate = None, - include: str | Sequence[str] = None, exclude: str | Sequence[str] = None, **kwargs, ): @@ -388,6 +388,7 @@ def convert_dtype( exclude: str | Sequence[str] = None, ): """Append a :py:class:`~transforms.ConvertDType` transform to the adapter. + See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`. Parameters ---------- @@ -525,6 +526,24 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False): self.transforms.append(transform) return self + def map_dtype(self, keys: str | Sequence[str], to_dtype: str): + """Append a :py:class:`~transforms.ConvertDType` transform to the adapter. + See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to transform. + to_dtype : str + Target dtype + """ + if isinstance(keys, str): + keys = [keys] + + transform = MapTransform({key: ConvertDType(to_dtype) for key in keys}) + self.transforms.append(transform) + return self + def one_hot(self, keys: str | Sequence[str], num_classes: int): """Append a :py:class:`~transforms.OneHot` transform to the adapter. @@ -590,9 +609,9 @@ def sqrt(self, keys: str | Sequence[str]): def standardize( self, + include: str | Sequence[str] = None, *, predicate: Predicate = None, - include: str | Sequence[str] = None, exclude: str | Sequence[str] = None, **kwargs, ): @@ -621,9 +640,9 @@ def standardize( def to_array( self, + include: str | Sequence[str] = None, *, predicate: Predicate = None, - include: str | Sequence[str] = None, exclude: str | Sequence[str] = None, **kwargs, ): diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index 65f5750c8..706dd14d3 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -29,10 +29,10 @@ class FilterTransform(Transform): def __init__( self, + include: str | Sequence[str] = None, *, transform_constructor: Callable[..., ElementwiseTransform], predicate: Predicate = None, - include: str | Sequence[str] = None, exclude: str | Sequence[str] = None, **kwargs, ): diff --git a/bayesflow/utils/empty.py b/bayesflow/utils/empty.py new file mode 100644 index 000000000..14331dc37 --- /dev/null +++ b/bayesflow/utils/empty.py @@ -0,0 +1,13 @@ +class empty: + """ + Placeholder value for arguments left empty + + Usage: + + def f(x=empty): + if x is empty: + # we know the user did not pass x + if x is None: + # the user could have passed None explicitly + + """