Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@

def apply(
self,
include: str | Sequence[str] = None,
*,
forward: np.ufunc | str,
inverse: np.ufunc | str = None,
predicate: Predicate = None,
include: str | Sequence[str] = None,
exclude: str | Sequence[str] = None,
**kwargs,
):
Expand Down Expand Up @@ -388,6 +388,7 @@
exclude: str | Sequence[str] = None,
):
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`.

Parameters
----------
Expand Down Expand Up @@ -525,6 +526,24 @@
self.transforms.append(transform)
return self

def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`.

Parameters
----------
keys : str or Sequence of str
The names of the variables to transform.
to_dtype : str
Target dtype
"""
if isinstance(keys, str):
keys = [keys]

Check warning on line 541 in bayesflow/adapters/adapter.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/adapter.py#L540-L541

Added lines #L540 - L541 were not covered by tests

transform = MapTransform({key: ConvertDType(to_dtype) for key in keys})
self.transforms.append(transform)
return self

Check warning on line 545 in bayesflow/adapters/adapter.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/adapter.py#L543-L545

Added lines #L543 - L545 were not covered by tests

def one_hot(self, keys: str | Sequence[str], num_classes: int):
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.

Expand Down Expand Up @@ -590,9 +609,9 @@

def standardize(
self,
include: str | Sequence[str] = None,
*,
predicate: Predicate = None,
include: str | Sequence[str] = None,
exclude: str | Sequence[str] = None,
**kwargs,
):
Expand Down Expand Up @@ -621,9 +640,9 @@

def to_array(
self,
include: str | Sequence[str] = None,
*,
predicate: Predicate = None,
include: str | Sequence[str] = None,
exclude: str | Sequence[str] = None,
**kwargs,
):
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ class FilterTransform(Transform):

def __init__(
self,
include: str | Sequence[str] = None,
*,
transform_constructor: Callable[..., ElementwiseTransform],
predicate: Predicate = None,
include: str | Sequence[str] = None,
exclude: str | Sequence[str] = None,
**kwargs,
):
Expand Down
13 changes: 13 additions & 0 deletions bayesflow/utils/empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class empty:

Check warning on line 1 in bayesflow/utils/empty.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/empty.py#L1

Added line #L1 was not covered by tests
"""
Placeholder value for arguments left empty

Usage:

def f(x=empty):
if x is empty:
# we know the user did not pass x
if x is None:
# the user could have passed None explicitly

"""
Loading