Skip to content

Commit 759abbc

Browse files
committed
allow positional includes
1 parent 2babca5 commit 759abbc

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

bayesflow/adapters/adapter.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ def __len__(self):
233233

234234
def apply(
235235
self,
236+
include: str | Sequence[str] = None,
236237
*,
237238
forward: np.ufunc | str,
238239
inverse: np.ufunc | str = None,
239240
predicate: Predicate = None,
240-
include: str | Sequence[str] = None,
241241
exclude: str | Sequence[str] = None,
242242
**kwargs,
243243
):
@@ -388,6 +388,7 @@ def convert_dtype(
388388
exclude: str | Sequence[str] = None,
389389
):
390390
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
391+
See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`.
391392
392393
Parameters
393394
----------
@@ -525,6 +526,24 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False):
525526
self.transforms.append(transform)
526527
return self
527528

529+
def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
530+
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
531+
See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`.
532+
533+
Parameters
534+
----------
535+
keys : str or Sequence of str
536+
The names of the variables to transform.
537+
to_dtype : str
538+
Target dtype
539+
"""
540+
if isinstance(keys, str):
541+
keys = [keys]
542+
543+
transform = MapTransform({key: ConvertDType(to_dtype) for key in keys})
544+
self.transforms.append(transform)
545+
return self
546+
528547
def one_hot(self, keys: str | Sequence[str], num_classes: int):
529548
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.
530549
@@ -590,9 +609,9 @@ def sqrt(self, keys: str | Sequence[str]):
590609

591610
def standardize(
592611
self,
612+
include: str | Sequence[str] = None,
593613
*,
594614
predicate: Predicate = None,
595-
include: str | Sequence[str] = None,
596615
exclude: str | Sequence[str] = None,
597616
**kwargs,
598617
):
@@ -621,9 +640,9 @@ def standardize(
621640

622641
def to_array(
623642
self,
643+
include: str | Sequence[str] = None,
624644
*,
625645
predicate: Predicate = None,
626-
include: str | Sequence[str] = None,
627646
exclude: str | Sequence[str] = None,
628647
**kwargs,
629648
):

0 commit comments

Comments
 (0)